diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a1858c6b514..0df46c92530 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -155,6 +155,7 @@ Color ColorJitter v2.ColorJitter + v2.RandomChannelPermutation v2.RandomPhotometricDistort Grayscale v2.Grayscale diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 353cc846bed..5f4a9b62898 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -124,6 +124,7 @@ class TestSmoke: (transforms.RandomEqualize(p=1.0), None), (transforms.RandomGrayscale(p=1.0), None), (transforms.RandomInvert(p=1.0), None), + (transforms.RandomChannelPermutation(), None), (transforms.RandomPhotometricDistort(p=1.0), None), (transforms.RandomPosterize(bits=4, p=1.0), None), (transforms.RandomSolarize(threshold=0.5, p=1.0), None), diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c910882f9fd..fa04d5deb0c 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2280,3 +2280,61 @@ def resize_my_datapoint(): _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint) assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint + + +class TestPermuteChannels: + _DEFAULT_PERMUTATION = [2, 0, 1] + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.permute_channels_image_tensor, make_image_tensor), + # FIXME + # check_kernel does not support PIL kernel, but it should + (F.permute_channels_image_tensor, make_image), + (F.permute_channels_video, make_video), + ], + ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel(self, kernel, make_input, dtype, device): + check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION) + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + 
(F.permute_channels_image_tensor, make_image_tensor), + (F.permute_channels_image_pil, make_image_pil), + (F.permute_channels_image_tensor, make_image), + (F.permute_channels_video, make_video), + ], + ) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION) + + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.permute_channels_image_tensor, torch.Tensor), + (F.permute_channels_image_pil, PIL.Image.Image), + (F.permute_channels_image_tensor, datapoints.Image), + (F.permute_channels_video, datapoints.Video), + ], + ) + def test_dispatcher_signature(self, kernel, input_type): + check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type) + + def reference_image_correctness(self, image, permutation): + channel_images = image.split(1, dim=-3) + permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation] + return datapoints.Image(torch.concat(permuted_channel_images, dim=-3)) + + @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 1, 0], [0, 1, 2]]) + @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)]) + def test_image_correctness(self, permutation, batch_dims): + image = make_image(batch_dims=batch_dims) + + actual = F.permute_channels(image, permutation=permutation) + expected = self.reference_image_correctness(image, permutation=permutation) + + torch.testing.assert_close(actual, expected) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 8ce9bee9b4d..4451cb7a1a2 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -11,6 +11,7 @@ Grayscale, RandomAdjustSharpness, RandomAutocontrast, + RandomChannelPermutation, RandomEqualize, RandomGrayscale, RandomInvert, diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 7dd8eeae236..8315e2f36b4 100644 --- 
a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -177,7 +177,27 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output -# TODO: This class seems to be untested +class RandomChannelPermutation(Transform): + """[BETA] Randomly permute the channels of an image or video + + .. v2betastatus:: RandomChannelPermutation transform + """ + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_channels, *_ = query_chw(flat_inputs) + return dict(permutation=torch.randperm(num_channels)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.permute_channels(inpt, params["permutation"]) + + class RandomPhotometricDistort(Transform): """[BETA] Randomly distorts the image or video as used in `SSD: Single Shot MultiBox Detector `_. @@ -241,21 +261,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None return params - def _permute_channels( - self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor - ) -> Union[datapoints._ImageType, datapoints._VideoType]: - orig_inpt = inpt - if isinstance(orig_inpt, PIL.Image.Image): - inpt = F.pil_to_tensor(inpt) - - # TODO: Find a better fix than as_subclass??? 
- output = inpt[..., permutation, :, :].as_subclass(type(inpt)) - - if isinstance(orig_inpt, PIL.Image.Image): - output = F.to_image_pil(output) - - return output - def _transform( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] ) -> Union[datapoints._ImageType, datapoints._VideoType]: @@ -270,7 +275,7 @@ def _transform( if params["contrast_factor"] is not None and not params["contrast_before"]: inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) if params["channel_permutation"] is not None: - inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) + inpt = F.permute_channels(inpt, permutation=params["channel_permutation"]) return inpt diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 163a55fad38..f3295860155 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -62,6 +62,10 @@ invert_image_pil, invert_image_tensor, invert_video, + permute_channels, + permute_channels_image_pil, + permute_channels_image_tensor, + permute_channels_video, posterize, posterize_image_pil, posterize_image_tensor, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 71797fd2500..9b6bf3886fa 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Union import PIL.Image import torch @@ -10,6 +10,8 @@ from torchvision.utils import _log_api_usage_once from ._misc import _num_value_bits, to_dtype_image_tensor + +from ._type_conversion import pil_to_tensor, to_image_pil from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @@ -641,3 +643,64 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(invert, datapoints.Video) def 
invert_video(video: torch.Tensor) -> torch.Tensor: return invert_image_tensor(video) + + +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT: + """Permute the channels of the input according to the given permutation. + + This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and + :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`. + + Example: + >>> rgb_image = torch.rand(3, 256, 256) + >>> bgr_image = F.permute_channels(rgb_image, permutation=[2, 1, 0]) + + Args: + permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the + channel index in the input and the value determines the channel index in the output. For example, + ``permutation=[2, 0, 1]`` + + - takes ``inpt[..., 0, :, :]`` and puts it at ``output[..., 2, :, :]``, + - takes ``inpt[..., 1, :, :]`` and puts it at ``output[..., 0, :, :]``, and + - takes ``inpt[..., 2, :, :]`` and puts it at ``output[..., 1, :, :]``. + + Raises: + ValueError: If ``len(permutation)`` doesn't match the number of channels in the input. 
+ """ + if torch.jit.is_scripting(): + return permute_channels_image_tensor(inpt, permutation=permutation) + + _log_api_usage_once(permute_channels) + + kernel = _get_kernel(permute_channels, type(inpt)) + return kernel(inpt, permutation=permutation) + + +@_register_kernel_internal(permute_channels, torch.Tensor) +@_register_kernel_internal(permute_channels, datapoints.Image) +def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: + shape = image.shape + num_channels, height, width = shape[-3:] + + if len(permutation) != num_channels: + raise ValueError( + f"Length of permutation does not match number of channels: " f"{len(permutation)} != {num_channels}" + ) + + if image.numel() == 0: + return image + + image = image.reshape(-1, num_channels, height, width) + image = image[:, permutation, :, :] + return image.reshape(shape) + + +@_register_kernel_internal(permute_channels, PIL.Image.Image) +def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image: + return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation)) + + +@_register_kernel_internal(permute_channels, datapoints.Video) +def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: + return permute_channels_image_tensor(video, permutation=permutation)