diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index e7d058e8da2..3eead60aed5 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -97,6 +97,23 @@ def test_adjustments(self): self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) self.assertTrue(torch.equal(img, img_clone)) + # test for class interface + f = transforms.ColorJitter(brightness=factor.item()) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + + f = transforms.ColorJitter(contrast=factor.item()) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + + f = transforms.ColorJitter(saturation=factor.item()) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + + f = transforms.ColorJitter(brightness=1) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + def test_rgb_to_grayscale(self): script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale) img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index a9f76fcff07..7791dd8b4f9 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -39,6 +39,32 @@ def test_random_horizontal_flip(self): def test_random_vertical_flip(self): self._test_flip('vflip', 'RandomVerticalFlip') + def test_adjustments(self): + fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation'] + for _ in range(20): + factor = 3 * torch.rand(1).item() + tensor, _ = self._create_data() + pil_img = T.ToPILImage()(tensor) + + for func in fns: + adjusted_tensor = getattr(F, func)(tensor, factor) + adjusted_pil_img = getattr(F, func)(pil_img, factor) + + adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img) + scripted_fn = torch.jit.script(getattr(F, func)) + adjusted_tensor_script = scripted_fn(tensor, factor) + + if not tensor.dtype.is_floating_point: + adjusted_tensor = adjusted_tensor.to(torch.float) / 255 + adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255 + + # F uses 
uint8 and F_t uses float, so there is a small + # difference in values caused by (at most 5) truncations. + max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max() + max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max() + self.assertLess(max_diff, 5 / 255 + 1e-5) + self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 3c8c364c3fa..d19b26e36b2 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -633,67 +633,61 @@ def ten_crop(img, size, vertical_flip=False): return first_five + second_five -def adjust_brightness(img, brightness_factor): +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an Image. Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Torch Tensor): Image to be adjusted. brightness_factor (float): How much to adjust the brightness. Can be any non negative number. 0 gives a black image, 1 gives the original image while 2 increases the brightness by a factor of 2. Returns: - PIL Image: Brightness adjusted image. + PIL Image or Torch Tensor: Brightness adjusted image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_brightness(img, brightness_factor) - enhancer = ImageEnhance.Brightness(img) - img = enhancer.enhance(brightness_factor) - return img + return F_t.adjust_brightness(img, brightness_factor) -def adjust_contrast(img, contrast_factor): +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an Image. Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Torch Tensor): Image to be adjusted. contrast_factor (float): How much to adjust the contrast. Can be any non negative number. 
0 gives a solid gray image, 1 gives the original image while 2 increases the contrast by a factor of 2. Returns: - PIL Image: Contrast adjusted image. + PIL Image or Torch Tensor: Contrast adjusted image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_contrast(img, contrast_factor) - enhancer = ImageEnhance.Contrast(img) - img = enhancer.enhance(contrast_factor) - return img + return F_t.adjust_contrast(img, contrast_factor) -def adjust_saturation(img, saturation_factor): +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an image. Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Torch Tensor): Image to be adjusted. saturation_factor (float): How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. Returns: - PIL Image: Saturation adjusted image. + PIL Image or Torch Tensor: Saturation adjusted image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_saturation(img, saturation_factor) - enhancer = ImageEnhance.Color(img) - img = enhancer.enhance(saturation_factor) - return img + return F_t.adjust_saturation(img, saturation_factor) -def adjust_hue(img, hue_factor): +def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: """Adjust hue of an image. The image hue is adjusted by converting the image to HSV and @@ -718,26 +712,10 @@ def adjust_hue(img, hue_factor): Returns: PIL Image: Hue adjusted image. """ - if not(-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) - - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) - - input_mode = img.mode - if input_mode in {'L', '1', 'I', 'F'}: - return img - - h, s, v = img.convert('HSV').split() - - np_h = np.array(h, dtype=np.uint8) - # uint8 addition take cares of rotation across boundaries - with np.errstate(over='ignore'): - np_h += np.uint8(hue_factor * 255) - h = Image.fromarray(np_h, 'L') + if not isinstance(img, torch.Tensor): + return F_pil.adjust_hue(img, hue_factor) - img = Image.merge('HSV', (h, s, v)).convert(input_mode) - return img + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) def adjust_gamma(img, gamma, gain=1): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index e387924ad36..84e27e79040 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -4,6 +4,7 @@ except ImportError: accimage = None from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION +import numpy as np @torch.jit.unused @@ -44,3 +45,110 @@ def vflip(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return img.transpose(Image.FLIP_TOP_BOTTOM) + + +@torch.jit.unused +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an RGB image. + + Args: + img (PIL Image): Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +@torch.jit.unused +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. 
Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + Returns: + PIL Image: Contrast adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +@torch.jit.unused +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + Returns: + PIL Image: Saturation adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +@torch.jit.unused +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue + + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL Image: Hue adjusted image. 
+ """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition takes care of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 5c202f384ee..812d82e3825 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -865,7 +865,7 @@ def __repr__(self): return format_string -class ColorJitter(object): +class ColorJitter(torch.nn.Module): """Randomly change the brightness, contrast and saturation of an image. Args: @@ -882,20 +882,23 @@ class ColorJitter(object): hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
""" + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + super().__init__() self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): if isinstance(value, numbers.Number): if value < 0: raise ValueError("If {} is a single number, it must be non negative.".format(name)) - value = [center - value, center + value] + value = [center - float(value), center + float(value)] if clip_first_on_zero: - value[0] = max(value[0], 0) + value[0] = max(value[0], 0.0) elif isinstance(value, (tuple, list)) and len(value) == 2: if not bound[0] <= value[0] <= value[1] <= bound[1]: raise ValueError("{} values should be between {}".format(name, bound)) @@ -909,6 +912,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs return value @staticmethod + @torch.jit.unused def get_params(brightness, contrast, saturation, hue): """Get a randomized transform to be applied on image. @@ -941,17 +945,37 @@ def get_params(brightness, contrast, saturation, hue): return transform - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Input image. + img (PIL Image or Tensor): Input image. Returns: - PIL Image: Color jittered image. + PIL Image or Tensor: Color jittered image. 
""" - transform = self.get_params(self.brightness, self.contrast, - self.saturation, self.hue) - return transform(img) + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and self.brightness is not None: + brightness = self.brightness + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = F.adjust_brightness(img, brightness_factor) + + if fn_id == 1 and self.contrast is not None: + contrast = self.contrast + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = F.adjust_contrast(img, contrast_factor) + + if fn_id == 2 and self.saturation is not None: + saturation = self.saturation + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = F.adjust_saturation(img, saturation_factor) + + if fn_id == 3 and self.hue is not None: + hue = self.hue + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = F.adjust_hue(img, hue_factor) + + return img def __repr__(self): format_string = self.__class__.__name__ + '('