Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PermuteChannels transform #7624

Merged
merged 11 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ Color

ColorJitter
v2.ColorJitter
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
Grayscale
v2.Grayscale
Expand Down
1 change: 1 addition & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class TestSmoke:
(transforms.RandomEqualize(p=1.0), None),
(transforms.RandomGrayscale(p=1.0), None),
(transforms.RandomInvert(p=1.0), None),
(transforms.RandomChannelPermutation(), None),
(transforms.RandomPhotometricDistort(p=1.0), None),
(transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
Expand Down
58 changes: 58 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2280,3 +2280,61 @@ def resize_my_datapoint():
_register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)

assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint


class TestPermuteChannels:
    """Tests for ``F.permute_channels`` and its type-specific kernels."""

    # Non-trivial permutation shared by the kernel and dispatcher checks.
    _DEFAULT_PERMUTATION = [2, 0, 1]

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            # FIXME
            # check_kernel does not support PIL kernel, but it should
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel(self, kernel, make_input, dtype, device):
        """Run the shared ``check_kernel`` checks for every tensor-based kernel, dtype and device."""
        check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            (F.permute_channels_image_pil, make_image_pil),
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    def test_dispatcher(self, kernel, make_input):
        """Run the shared ``check_dispatcher`` checks for ``F.permute_channels`` against each kernel."""
        check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.permute_channels_image_tensor, torch.Tensor),
            (F.permute_channels_image_pil, PIL.Image.Image),
            (F.permute_channels_image_tensor, datapoints.Image),
            (F.permute_channels_video, datapoints.Video),
        ],
    )
    def test_dispatcher_signature(self, kernel, input_type):
        """Check that the dispatcher signature and the kernel signature for ``input_type`` agree."""
        check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)

    def reference_image_correctness(self, image, permutation):
        """Reference implementation used by ``test_image_correctness``.

        Splits ``image`` into single-channel slices along dim ``-3`` and concatenates them back in the
        order given by ``permutation``, i.e. output channel ``i`` is input channel ``permutation[i]``.
        """
        channel_images = image.split(1, dim=-3)
        permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
        return datapoints.Image(torch.concat(permuted_channel_images, dim=-3))

    # The original list contained ``[2, 0, 1]`` twice; the duplicate is replaced by ``[2, 1, 0]`` so that
    # every parametrization exercises a distinct permutation (two rotations, one swap, and the identity).
    @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 1, 0], [0, 1, 2]])
    @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
    def test_image_correctness(self, permutation, batch_dims):
        """Compare ``F.permute_channels`` against the split/concat reference for several batch shapes."""
        image = make_image(batch_dims=batch_dims)

        actual = F.permute_channels(image, permutation=permutation)
        expected = self.reference_image_correctness(image, permutation=permutation)

        torch.testing.assert_close(actual, expected)
1 change: 1 addition & 0 deletions torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomChannelPermutation,
RandomEqualize,
RandomGrayscale,
RandomInvert,
Expand Down
39 changes: 22 additions & 17 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,27 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return output


# TODO: This class seems to be untested
class RandomChannelPermutation(Transform):
    """[BETA] Shuffle the channels of an image or video into a random order.

    .. v2betastatus:: RandomChannelPermutation transform
    """

    # Input kinds this transform is applied to.
    _transformed_types = (datapoints.Image, PIL.Image.Image, is_simple_tensor, datapoints.Video)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Only the leading entry returned by ``query_chw`` (the channel count) is needed here.
        num_channels, *_ = query_chw(flat_inputs)
        channel_order = torch.randperm(num_channels)
        return {"permutation": channel_order}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Apply the permutation sampled in ``_get_params`` to this input.
        channel_order = params["permutation"]
        return F.permute_channels(inpt, channel_order)


class RandomPhotometricDistort(Transform):
"""[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.
Expand Down Expand Up @@ -241,21 +261,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
return params

def _permute_channels(
    self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
) -> Union[datapoints._ImageType, datapoints._VideoType]:
    """Reorder the channel dimension (dim ``-3``) of ``inpt`` according to ``permutation``.

    PIL images are converted to a tensor first and converted back to PIL afterwards; any other input
    is indexed directly and the result is re-wrapped in the type of the indexed tensor.
    """
    came_from_pil = isinstance(inpt, PIL.Image.Image)
    tensor = F.pil_to_tensor(inpt) if came_from_pil else inpt

    # TODO: Find a better fix than as_subclass???
    permuted = tensor[..., permutation, :, :].as_subclass(type(tensor))

    return F.to_image_pil(permuted) if came_from_pil else permuted

def _transform(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints._ImageType, datapoints._VideoType]:
Expand All @@ -270,7 +275,7 @@ def _transform(
if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
return inpt


Expand Down
4 changes: 4 additions & 0 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
invert_image_pil,
invert_image_tensor,
invert_video,
permute_channels,
permute_channels_image_pil,
permute_channels_image_tensor,
permute_channels_video,
posterize,
posterize_image_pil,
posterize_image_tensor,
Expand Down
65 changes: 64 additions & 1 deletion torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import List, Union

import PIL.Image
import torch
Expand All @@ -10,6 +10,8 @@
from torchvision.utils import _log_api_usage_once

from ._misc import _num_value_bits, to_dtype_image_tensor

from ._type_conversion import pil_to_tensor, to_image_pil
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal


Expand Down Expand Up @@ -641,3 +643,64 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(invert, datapoints.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor:
    """Video kernel for ``invert``; delegates to the image tensor kernel."""
    inverted = invert_image_tensor(video)
    return inverted


@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
    """Permute the channels of the input according to the given permutation.

    This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
    :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.

    Example:
        >>> rgb_image = torch.rand(3, 256, 256)
        >>> bgr_image = F.permute_channels(rgb_image, permutation=[2, 1, 0])

    Args:
        inpt: Input whose channels should be permuted. The kernel is selected based on ``type(inpt)``.
        permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the
            channel index in the output and the value determines the channel index in the input, i.e.
            ``output[..., i, :, :] = inpt[..., permutation[i], :, :]``. For example, ``permutation=[2, 0, 1]``

            - takes ``inpt[..., 2, :, :]`` and puts it at ``output[..., 0, :, :]``,
            - takes ``inpt[..., 0, :, :]`` and puts it at ``output[..., 1, :, :]``, and
            - takes ``inpt[..., 1, :, :]`` and puts it at ``output[..., 2, :, :]``.

    Returns:
        The result of the kernel registered for ``type(inpt)``.

    Raises:
        ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.
    """
    # Under TorchScript no type-based dispatch happens; the tensor kernel is called directly.
    if torch.jit.is_scripting():
        return permute_channels_image_tensor(inpt, permutation=permutation)

    _log_api_usage_once(permute_channels)

    kernel = _get_kernel(permute_channels, type(inpt))
    return kernel(inpt, permutation=permutation)


@_register_kernel_internal(permute_channels, torch.Tensor)
@_register_kernel_internal(permute_channels, datapoints.Image)
def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Tensor kernel for ``permute_channels``.

    Treats the last three dimensions of ``image`` as ``(channels, height, width)`` and returns a tensor
    of the same shape in which output channel ``i`` is input channel ``permutation[i]``.

    Raises:
        ValueError: If ``len(permutation)`` differs from the size of dim ``-3``.
    """
    original_shape = image.shape
    num_channels, height, width = original_shape[-3:]

    if num_channels != len(permutation):
        raise ValueError(
            f"Length of permutation does not match number of channels: {len(permutation)} != {num_channels}"
        )

    # Nothing to reorder for empty tensors; hand the input back untouched.
    if image.numel() == 0:
        return image

    # Collapse all leading (batch) dimensions so the channel axis sits at a fixed position,
    # reorder the channels, and then restore the original shape.
    flattened = image.reshape(-1, num_channels, height, width)
    return flattened[:, permutation, :, :].reshape(original_shape)


@_register_kernel_internal(permute_channels, PIL.Image.Image)
def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
    """PIL kernel for ``permute_channels``.

    Converts ``image`` to a tensor with ``pil_to_tensor``, permutes its channels with
    ``permute_channels_image_tensor`` and converts the result back with ``to_image_pil``.
    """
    # NOTE: the return annotation is ``PIL.Image.Image`` (the class); ``PIL.Image`` is the module.
    return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation))


@_register_kernel_internal(permute_channels, datapoints.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Video kernel for ``permute_channels``; delegates to the image tensor kernel."""
    permuted = permute_channels_image_tensor(video, permutation=permutation)
    return permuted