diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py new file mode 100644 index 00000000000..a9f76fcff07 --- /dev/null +++ b/test/test_transforms_tensor.py @@ -0,0 +1,44 @@ +import torch +from torchvision import transforms as T +from torchvision.transforms import functional as F +from PIL import Image + +import numpy as np + +import unittest + + +class Tester(unittest.TestCase): + def _create_data(self, height=3, width=3, channels=3): + tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8) + pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy()) + return tensor, pil_img + + def compareTensorToPIL(self, tensor, pil_image): + pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) + self.assertTrue(tensor.equal(pil_tensor)) + + def _test_flip(self, func, method): + tensor, pil_img = self._create_data() + flip_tensor = getattr(F, func)(tensor) + flip_pil_img = getattr(F, func)(pil_img) + self.compareTensorToPIL(flip_tensor, flip_pil_img) + + scripted_fn = torch.jit.script(getattr(F, func)) + flip_tensor_script = scripted_fn(tensor) + self.assertTrue(flip_tensor.equal(flip_tensor_script)) + + # test for class interface + f = getattr(T, method)() + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + def test_random_horizontal_flip(self): + self._test_flip('hflip', 'RandomHorizontalFlip') + + def test_random_vertical_flip(self): + self._test_flip('vflip', 'RandomVerticalFlip') + + +if __name__ == '__main__': + unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7f22fc51391..22adba6ccc5 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,4 +1,5 @@ import torch +from torch import Tensor import math from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION try: @@ -11,6 +12,9 @@ from collections.abc import Sequence, Iterable import warnings +from . 
import functional_pil as F_pil +from . import functional_tensor as F_t + def _is_pil_image(img): if accimage is not None: @@ -428,19 +432,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE return img -def hflip(img): - """Horizontally flip the given PIL Image. +def hflip(img: Tensor) -> Tensor: + """Horizontally flip the given PIL Image or torch Tensor. Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Torch Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of leading + dimensions. Returns: PIL Image: Horizontally flipped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.hflip(img) - return img.transpose(Image.FLIP_LEFT_RIGHT) + return F_t.hflip(img) def _parse_fill(fill, img, min_pil_version): @@ -530,19 +537,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts) -def vflip(img): - """Vertically flip the given PIL Image. +def vflip(img: Tensor) -> Tensor: + """Vertically flip the given PIL Image or torch Tensor. Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Torch Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of leading + dimensions. Returns: PIL Image: Vertically flipped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.vflip(img) - return img.transpose(Image.FLIP_TOP_BOTTOM) + return F_t.vflip(img) def five_crop(img, size): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py new file mode 100644 index 00000000000..e387924ad36 --- /dev/null +++ b/torchvision/transforms/functional_pil.py @@ -0,0 +1,46 @@ +import torch +try: + import accimage +except ImportError: + accimage = None +from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION + + +@torch.jit.unused +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +@torch.jit.unused +def hflip(img): + """Horizontally flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Horizontally flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +@torch.jit.unused +def vflip(img): + """Vertically flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Vertically flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_TOP_BOTTOM) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index b81deed6d43..c0815393c37 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,11 +1,10 @@ import torch -import torchvision.transforms.functional as F from torch import Tensor from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple def _is_tensor_a_torch_image(input): - return len(input.shape) == 3 + return input.ndim >= 2 def vflip(img): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 49fac26e395..5c202f384ee 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -500,25 +500,29 @@ def __repr__(self): return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) -class RandomHorizontalFlip(object): - """Horizontally flip the given PIL Image randomly with a given probability. +class RandomHorizontalFlip(torch.nn.Module): + """Horizontally flip the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: p (float): probability of the image being flipped. Default value is 0.5 """ def __init__(self, p=0.5): + super().__init__() self.p = p - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. Returns: - PIL Image: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. 
""" - if random.random() < self.p: + if torch.rand(1) < self.p: return F.hflip(img) return img @@ -526,25 +530,29 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomVerticalFlip(object): +class RandomVerticalFlip(torch.nn.Module): """Vertically flip the given PIL Image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: p (float): probability of the image being flipped. Default value is 0.5 """ def __init__(self, p=0.5): + super().__init__() self.p = p - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. Returns: - PIL Image: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. """ - if random.random() < self.p: + if torch.rand(1) < self.p: return F.vflip(img) return img