From b9aaf4330787e715ced25338496563b96ea82a20 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 19 Jan 2023 16:44:17 +0000 Subject: [PATCH 1/4] Let Normalize() and RandomPhotometricDistort return datapoints instead of tensors --- test/test_prototype_transforms_functional.py | 4 +-- torchvision/prototype/datapoints/_image.py | 4 +++ torchvision/prototype/datapoints/_video.py | 4 +++ torchvision/prototype/transforms/_color.py | 10 +++--- .../prototype/transforms/functional/_misc.py | 31 ++++++++++++------- 5 files changed, 34 insertions(+), 19 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index d199625df0f..b9f68481d54 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1185,15 +1185,15 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") +# TODO: I guess we need to change the name of this test. Should we have a +# _correctness test as well like the rest? def test_normalize_output_type(): inpt = torch.rand(1, 3, 32, 32) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - assert type(output) is torch.Tensor torch.testing.assert_close(inpt - 0.5, output) inpt = make_image(color_space=datapoints.ColorSpace.RGB) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - assert type(output) is torch.Tensor torch.testing.assert_close(inpt - 0.5, output) diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index fc20691100f..a6a8c928334 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -289,6 +289,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N ) return Image.wrap_like(self, output) + def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) + return Image.wrap_like(self, output) + ImageType = Union[torch.Tensor, PIL.Image.Image, Image] ImageTypeJIT = torch.Tensor diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 5c55d23a149..af58a5d2c1f 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -241,6 +241,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma) return Video.wrap_like(self, output) + def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) + return Video.wrap_like(self, output) + VideoType = Union[torch.Tensor, Video] VideoTypeJIT = torch.Tensor diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 0254dd7c225..4109ccfccf7 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -82,6 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output +# TODO: Are there tests for this class? class RandomPhotometricDistort(Transform): _transformed_types = ( datapoints.Image, @@ -119,15 +120,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _permute_channels( self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor ) -> Union[datapoints.ImageType, datapoints.VideoType]: - if isinstance(inpt, PIL.Image.Image): + + orig_inpt = inpt + if isinstance(orig_inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) output = inpt[..., permutation, :, :] - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type] - - elif isinstance(inpt, PIL.Image.Image): + if isinstance(orig_inpt, PIL.Image.Image): output = F.to_image_pil(output) return output diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 59570768160..0a0ede90940 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -58,20 +58,27 @@ def normalize( std: List[float], inplace: bool = False, ) -> torch.Tensor: + # if torch.jit.is_scripting() or is_simple_tensor(inpt): + # return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) + # elif isinstance(inpt, datapoints._datapoint.Datapoint): + # return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) + # elif isinstance(inpt, PIL.Image.Image): + # return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) + # else: + # raise TypeError( + # f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + # f"but got {type(inpt)} instead." + # ) if not torch.jit.is_scripting(): _log_api_usage_once(normalize) - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - inpt = inpt.as_subclass(torch.Tensor) - elif not is_simple_tensor(inpt): - raise TypeError( - f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " - f"but got {type(inpt)} instead." - ) - - # Image or Video type should not be retained after normalization due to unknown data range - # Thus we return Tensor for input Image - return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + elif isinstance(inpt, (datapoints.Image, datapoints.Video)): + return inpt.normalize(mean=mean, std=std, inplace=inplace) + else: + raise TypeError( + f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." + ) def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: From 5a3c51ac119fa547ac03233f326f130eaade45ef Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 19 Jan 2023 16:45:46 +0000 Subject: [PATCH 2/4] cleanup --- torchvision/prototype/transforms/functional/_misc.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 0a0ede90940..9d0a00f88c3 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -58,17 +58,6 @@ def normalize( std: List[float], inplace: bool = False, ) -> torch.Tensor: - # if torch.jit.is_scripting() or is_simple_tensor(inpt): - # return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) - # elif isinstance(inpt, datapoints._datapoint.Datapoint): - # return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) - # elif isinstance(inpt, PIL.Image.Image): - # return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) - # else: - # raise TypeError( - # f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - # f"but got {type(inpt)} instead." - # ) if not torch.jit.is_scripting(): _log_api_usage_once(normalize) if torch.jit.is_scripting() or is_simple_tensor(inpt): From cb4c0f4996395cc0b75c8f84e55860b9fc2635da Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 08:04:12 +0000 Subject: [PATCH 3/4] Address comments --- test/prototype_transforms_dispatcher_infos.py | 1 - test/test_prototype_transforms_functional.py | 12 ------------ torchvision/prototype/transforms/_color.py | 2 +- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index b92278fef56..90e2e7f570f 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -426,7 +426,6 @@ def fill_sequence_needs_broadcast(args_kwargs): datapoints.Video: F.normalize_video, }, test_marks=[ - skip_dispatch_feature, xfail_jit_python_scalar_arg("mean"), xfail_jit_python_scalar_arg("std"), ], diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index b9f68481d54..0b3570250a6 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") -# TODO: I guess we need to change the name of this test. Should we have a -# _correctness test as well like the rest? -def test_normalize_output_type(): - inpt = torch.rand(1, 3, 32, 32) - output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - torch.testing.assert_close(inpt - 0.5, output) - - inpt = make_image(color_space=datapoints.ColorSpace.RGB) - output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - torch.testing.assert_close(inpt - 0.5, output) - - @pytest.mark.parametrize( "inpt", [ diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 4109ccfccf7..0eb20e57764 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -82,7 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output -# TODO: Are there tests for this class? +# TODO: This class seems to be untested class RandomPhotometricDistort(Transform): _transformed_types = ( datapoints.Image, From ac94d7259a564d5cd52777811e1677ec18bf7a32 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Jan 2023 09:00:20 +0000 Subject: [PATCH 4/4] lint + types --- test/test_prototype_transforms_functional.py | 2 +- torchvision/prototype/datapoints/_image.py | 2 +- torchvision/prototype/datapoints/_video.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 0b3570250a6..a80e0f4570d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -13,7 +13,7 @@ import torchvision.prototype.transforms.utils from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed -from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message +from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS from torch.utils._pytree import tree_map diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index a6a8c928334..d674745a716 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -289,7 +289,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N ) return Image.wrap_like(self, output) - def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image: output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) return Image.wrap_like(self, output) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index af58a5d2c1f..c7273874655 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -241,7 +241,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma) return Video.wrap_like(self, output) - def normalize(self, mean: List[float], std: List[float], inplace: bool = False): + def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video: output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) return Video.wrap_like(self, output)