diff --git a/test/test_transforms.py b/test/test_transforms.py index a53f14d4d92..5724d0a1b3f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -422,6 +422,167 @@ def test_random_horizontal_flip(self): p_value = stats.binom_test(num_horizontal, 100, p=0.5) assert p_value > 0.05 + def test_adjust_brightness(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_brightness(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_brightness(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_brightness(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_contrast(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_contrast(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_contrast(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_contrast(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_saturation(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_saturation(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_saturation(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_saturation(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_hue(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + with self.assertRaises(ValueError): + transforms.adjust_hue(x_pil, -0.7) + transforms.adjust_hue(x_pil, 1) + + # test 0: almost same as x_data but not exact. + # probably because hsv <-> rgb floating point ops + y_pil = transforms.adjust_hue(x_pil, 0) + y_np = np.array(y_pil) + y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 1 + y_pil = transforms.adjust_hue(x_pil, 0.25) + y_np = np.array(y_pil) + y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_hue(x_pil, -0.25) + y_np = np.array(y_pil) + y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_gamma(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_gamma(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_gamma(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_gamma(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjusts_L_mode(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_rgb = Image.fromarray(x_np, mode='RGB') + + x_l = x_rgb.convert('L') + assert transforms.adjust_brightness(x_l, 2).mode == 'L' + assert transforms.adjust_saturation(x_l, 2).mode == 'L' + assert transforms.adjust_contrast(x_l, 2).mode == 'L' + assert transforms.adjust_hue(x_l, 0.4).mode == 'L' + assert transforms.adjust_gamma(x_l, 0.5).mode == 'L' + + def test_color_jitter(self): + color_jitter = transforms.ColorJitter(2, 2, 2, 0.1) + + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_pil_2 = x_pil.convert('L') + + for i in range(10): + y_pil = color_jitter(x_pil) + assert y_pil.mode == x_pil.mode + + y_pil_2 = color_jitter(x_pil_2) + assert y_pil_2.mode == x_pil_2.mode + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms.py b/torchvision/transforms.py index deff17115d2..15fec0c9c65 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -2,7 +2,7 @@ import torch import math import random -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageEnhance try: import accimage except ImportError: @@ -348,6 +348,145 @@ def ten_crop(img, size, vertical_flip=False): return first_five + second_five +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. + + Args: + img (PIL.Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL.Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + + Args: + img (PIL.Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL.Image: Contrast adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + + Args: + img (PIL.Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL.Image: Saturation adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See https://en.wikipedia.org/wiki/Hue for more details on Hue. + + Args: + img (PIL.Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL.Image: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + +def adjust_gamma(img, gamma, gain=1): + """Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + I_out = 255 * gain * ((I_in / 255) ** gamma) + + See https://en.wikipedia.org/wiki/Gamma_correction for more details. + + Args: + img (PIL.Image): PIL Image to be adjusted. + gamma (float): Non negative real number. gamma larger than 1 make the + shadows darker, while gamma smaller than 1 make dark regions + lighter. + gain (float): The constant multiplier. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + + np_img = np.array(img, dtype=np.float32) + np_img = 255 * gain * ((np_img / 255) ** gamma) + np_img = np.uint8(np.clip(np_img, 0, 255)) + + img = Image.fromarray(np_img, 'RGB').convert(input_mode) + return img + + class Compose(object): """Composes several transforms together. @@ -756,3 +895,67 @@ def __init__(self, size, vertical_flip=False): def __call__(self, img): return ten_crop(img, self.size, self.vertical_flip) + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation of an image. + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + + Arguments are same as that of __init__. + + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. + """ + transforms = [] + if brightness > 0: + brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) + transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) + + if contrast > 0: + contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) + transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) + + if saturation > 0: + saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) + transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) + + if hue > 0: + hue_factor = np.random.uniform(-hue, hue) + transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) + + np.random.shuffle(transforms) + transform = Compose(transforms) + + return transform + + def __call__(self, img): + """ + Args: + img (PIL.Image): Input image. + + Returns: + PIL.Image: Color jittered image. + """ + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + return transform(img)