From 35c2c09929612f5ff5e82a9e9138e562bdda316d Mon Sep 17 00:00:00 2001
From: Sasank Chilamkurthy
Date: Wed, 27 Sep 2017 16:50:24 +0530
Subject: [PATCH 01/10] Add adjust_hue and adjust_saturation

---
 torchvision/transforms.py | 64 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index deff17115d2..4c0d708862f 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -2,7 +2,7 @@
 import torch
 import math
 import random
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageChops
 try:
     import accimage
 except ImportError:
@@ -348,6 +348,68 @@ def ten_crop(img, size, vertical_flip=False):
     return first_five + second_five
 
 
+def adjust_hue(img, delta):
+    """Adjust hue of an RGB image.
+
+    `image` is an RGB image. The image hue is adjusted by converting the
+    image to HSV and cyclically rotating the intensities in hue channel (H).
+    The image is then converted back to RGB.
+
+    `delta` must be in the interval `[-1, 1]`.
+
+    Args:
+        image: RGB image. Size of the last dimension must be 3.
+        delta: float. How much to rotate the hue channel. 1 and -1 are
+            complete rotation in positive and negative direction respectively.
+            0 means no rotation.
+
+    Returns:
+        PIL.Image: Adjusted image.
+    """
+    assert delta < 1 and delta >= -1, 'delta out of range.'
+
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    h, s, v = img.convert('HSV').split()
+
+    np_h = np.array(h, dtype='uint8')
+    # uint8 addition take cares of rotation across boundaries
+    np_h += np.uint8(delta * 255)
+    h = Image.fromarray(np_h, 'L')
+
+    img = Image.merge('HSV', (h, s, v)).convert('RGB')
+    return img
+
+
+def adjust_saturation(img, saturation_factor):
+    """Adjust saturation of an RGB image.
+
+    `image` is an RGB image. The image saturation is adjusted by converting the
+    image to HSV and multiplying the saturation (S) channel by
+    `saturation_factor` and clipping. The image is then converted back to RGB.
+
+    Args:
+        image: RGB image or images. Size of the last dimension must be 3.
+        saturation_factor: float. Factor to multiply the saturation by.
+
+    Returns:
+        Adjusted image(s), same shape and DType as `image`.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    h, s, v = img.convert('HSV').split()
+
+    np_s = np.array(s)
+    np_s = np_s * saturation_factor
+    np_s = np.clip(np_s, 0, 255).astype('uint8')
+    s = Image.fromarray(np_s, 'L')
+
+    img = Image.merge('HSV', (h, s, v)).convert('RGB')
+    return img
+
+
 class Compose(object):
     """Composes several transforms together.
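The hue adjustment added above works through a PIL HSV round-trip plus uint8 wrap-around: adding to the H channel overflows modulo 256, which gives the cyclic rotation for free. A minimal standalone sketch of the same idea, assuming only Pillow and numpy; the helper name rotate_hue and the demo image are illustrative, not part of the patch:

import numpy as np
from PIL import Image


def rotate_hue(img, delta):
    """Cyclically shift the hue of a PIL RGB image; delta is in [-1, 1]."""
    if not -1.0 <= delta <= 1.0:
        raise ValueError('delta should be in [-1, 1], got {}'.format(delta))

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition wraps modulo 256, which is exactly the cyclic shift needed
    with np.errstate(over='ignore'):
        np_h += np.uint8(delta * 255)
    h = Image.fromarray(np_h, 'L')

    return Image.merge('HSV', (h, s, v)).convert('RGB')


if __name__ == '__main__':
    red = Image.new('RGB', (4, 4), (255, 0, 0))
    shifted = rotate_hue(red, 1 / 3.0)     # roughly a 120 degree hue rotation
    print(shifted.getpixel((0, 0)))        # close to pure green

Note that a full rotation (delta of 1 or -1) brings every hue essentially back to where it started, which is consistent with the later patches in this series restricting the shift to a hue_factor in [-0.5, 0.5], where +/-0.5 yields complementary colors.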
From 6e59812ed4d64d8ec8ccff89e6f608a6e9e5948f Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Wed, 27 Sep 2017 19:21:22 +0530 Subject: [PATCH 02/10] Add adjust_brightness, adjust_contrast Also * Change adjust_saturation to use pillow implementation * Documentation made clear --- torchvision/transforms.py | 101 ++++++++++++++++++++++++++------------ 1 file changed, 69 insertions(+), 32 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 4c0d708862f..b3c5335e99d 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -2,7 +2,7 @@ import torch import math import random -from PIL import Image, ImageOps, ImageChops +from PIL import Image, ImageOps, ImageEnhance try: import accimage except ImportError: @@ -348,63 +348,100 @@ def ten_crop(img, size, vertical_flip=False): return first_five + second_five -def adjust_hue(img, delta): - """Adjust hue of an RGB image. +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. - `image` is an RGB image. The image hue is adjusted by converting the - image to HSV and cyclically rotating the intensities in hue channel (H). - The image is then converted back to RGB. + Args: + img (PIL.Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL.Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img - `delta` must be in the interval `[-1, 1]`. + +def adjust_contrast(img, contrast_factor): + """Adjust brightness of an Image. Args: - image: RGB image. Size of the last dimension must be 3. - delta: float. How much to rotate the hue channel. 1 and -1 are - complete rotation in positive and negative direction respectively. - 0 means no rotation. + img (PIL.Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. Returns: - PIL.Image: Adjusted image. + PIL.Image: Contrast adjusted image. """ - assert delta < 1 and delta >= -1, 'delta out of range.' - if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - h, s, v = img.convert('HSV').split() + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img - np_h = np.array(h, dtype='uint8') - # uint8 addition take cares of rotation across boundaries - np_h += np.uint8(delta * 255) - h = Image.fromarray(np_h, 'L') - img = Image.merge('HSV', (h, s, v)).convert('RGB') +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + + Args: + img (PIL.Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and weight image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + Adjusted image(s), same shape and DType as `image`. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) return img -def adjust_saturation(img, saturation_factor): - """Adjust saturation of an RGB image. +def adjust_hue(img, hue_factor): + """Adjust hue of an image. - `image` is an RGB image. The image saturation is adjusted by converting the - image to HSV and multiplying the saturation (S) channel by - `saturation_factor` and clipping. The image is then converted back to RGB. + `image` is an RGB image. The image hue is adjusted by converting the + image to HSV and cyclically shifting the intensities in hue channel (H). + The image is then converted back to RGB. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See https://en.wikipedia.org/wiki/Hue for more details on Hue. Args: - image: RGB image or images. Size of the last dimension must be 3. - saturation_factor: float. Factor to multiply the saturation by. + img (PIL.Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. Returns: - Adjusted image(s), same shape and DType as `image`. + PIL.Image: Hue adjusted image. """ + assert hue_factor <= 0.5 and hue_factor >= -0.5, 'hue_factor out of range.' + if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) h, s, v = img.convert('HSV').split() - np_s = np.array(s) - np_s = np_s * saturation_factor - np_s = np.clip(np_s, 0, 255).astype('uint8') - s = Image.fromarray(np_s, 'L') + np_h = np.array(h, dtype='uint8') + # uint8 addition take cares of rotation across boundaries + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') img = Image.merge('HSV', (h, s, v)).convert('RGB') return img From 998c14e8fde385e1ab7838a8fc1fc234c784108d Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Wed, 27 Sep 2017 20:00:21 +0530 Subject: [PATCH 03/10] Add adjust_gamma --- torchvision/transforms.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index b3c5335e99d..08448fa2825 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -431,7 +431,8 @@ def adjust_hue(img, hue_factor): Returns: PIL.Image: Hue adjusted image. """ - assert hue_factor <= 0.5 and hue_factor >= -0.5, 'hue_factor out of range.' + if hue_factor <= 0.5 and hue_factor >= -0.5: + raise ValueError('hue_factor {} out of range.'.format(hue_factor)) if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) @@ -447,6 +448,39 @@ def adjust_hue(img, hue_factor): return img +def adjust_gamma(img, gamma, gain=1): + """Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities are adjusted based on the + following equation: + + I_out = 255 * gain * ((I_in / 255) ** gamma) + + See https://en.wikipedia.org/wiki/Gamma_correction for more details. + + Args: + img (PIL.Image): PIL Image to be adjusted. + gamma (float): Non negative real number. gamma larger than 1 make the + shadows darker, while gamma smaller than 1 make dark regions + lighter. + gain (float): The constant multiplier. 
+ """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + img = img.convert('RGB') + + np_img = np.array(img) + np_img = 255 * gain * ((np_img / 255) ** gamma) + np_img = np.uint8(np.clip(np_img, 0, 255)) + + img = Image.fromarray(np_img, 'RGB') + return img + + class Compose(object): """Composes several transforms together. From 6fbf634f7b6933daf36813ae2317c89ee96ca2e8 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Thu, 28 Sep 2017 17:53:21 +0530 Subject: [PATCH 04/10] Add ColorJitter --- torchvision/transforms.py | 68 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 08448fa2825..1559eaf4f7f 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -394,7 +394,7 @@ def adjust_saturation(img, saturation_factor): Args: img (PIL.Image): PIL Image to be adjusted. saturation_factor (float): How much to adjust the saturation. 0 will - give a black and weight image, 1 will give the original image while + give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. Returns: @@ -431,7 +431,7 @@ def adjust_hue(img, hue_factor): Returns: PIL.Image: Hue adjusted image. """ - if hue_factor <= 0.5 and hue_factor >= -0.5: + if not(hue_factor <= 0.5 and hue_factor >= -0.5): raise ValueError('hue_factor {} out of range.'.format(hue_factor)) if not _is_pil_image(img): @@ -889,3 +889,67 @@ def __init__(self, size, vertical_flip=False): def __call__(self, img): return ten_crop(img, self.size, self.vertical_flip) + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation of an image. + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), brightness] + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), contrast] + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), saturation] + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + + Arguments are same as that of __init__. + + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. 
+ """ + transforms = [] + if brightness > 0: + brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) + transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) + + if contrast > 0: + contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) + transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) + + if saturation > 0: + saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) + transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) + + if hue > 0: + hue_factor = np.random.uniform(-hue, hue) + transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) + + np.random.shuffle(transforms) + transform = Compose(transforms) + + return transform + + def __call__(self, img): + """ + Args: + img (PIL.Image): Input image. + + Returns: + PIL.Image: Color jittered image. + """ + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + return transform(img) From 26734c60b70a62758e2f6cc0f6eec703e503ab21 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Fri, 29 Sep 2017 18:47:12 +0530 Subject: [PATCH 05/10] Address review comments --- torchvision/transforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 1559eaf4f7f..1461cb74b53 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -441,7 +441,8 @@ def adjust_hue(img, hue_factor): np_h = np.array(h, dtype='uint8') # uint8 addition take cares of rotation across boundaries - np_h += np.uint8(hue_factor * 255) + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) h = Image.fromarray(np_h, 'L') img = Image.merge('HSV', (h, s, v)).convert('RGB') @@ -473,7 +474,7 @@ def adjust_gamma(img, gamma, gain=1): img = img.convert('RGB') - np_img = np.array(img) + np_img = np.array(img, dtype='float32') np_img = 255 * gain * ((np_img / 255) ** gamma) np_img = np.uint8(np.clip(np_img, 0, 255)) From 7d303d1a091be99283e926a3eb224bbc27c52652 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Fri, 29 Sep 2017 23:48:50 +0530 Subject: [PATCH 06/10] Fix documentation for ColorJitter --- torchvision/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 1461cb74b53..0df6249d204 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -897,11 +897,11 @@ class ColorJitter(object): Args: brightness (float): How much to jitter brightness. brightness_factor - is chosen uniformly from [max(0, 1 - brightness), brightness] + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. contrast (float): How much to jitter contrast. contrast_factor - is chosen uniformly from [max(0, 1 - contrast), contrast] + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. saturation (float): How much to jitter saturation. saturation_factor - is chosen uniformly from [max(0, 1 - saturation), saturation] + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. hue(float): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue]. Should be >=0 and <= 0.5. 
""" From fe1e15f08784f974a5c96b4b218a09452d90ef57 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Mon, 2 Oct 2017 14:49:37 +0530 Subject: [PATCH 07/10] Address review comments 2 --- torchvision/transforms.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 0df6249d204..6814c637a04 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -369,7 +369,7 @@ def adjust_brightness(img, brightness_factor): def adjust_contrast(img, contrast_factor): - """Adjust brightness of an Image. + """Adjust contrast of an Image. Args: img (PIL.Image): PIL Image to be adjusted. @@ -398,7 +398,7 @@ def adjust_saturation(img, saturation_factor): 2 will enhance the saturation by a factor of 2. Returns: - Adjusted image(s), same shape and DType as `image`. + PIL.Image: Saturation adjusted image. """ if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) @@ -411,9 +411,9 @@ def adjust_saturation(img, saturation_factor): def adjust_hue(img, hue_factor): """Adjust hue of an image. - `image` is an RGB image. The image hue is adjusted by converting the - image to HSV and cyclically shifting the intensities in hue channel (H). - The image is then converted back to RGB. + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. `hue_factor` is the amount of shift in H channel and must be in the interval `[-0.5, 0.5]`. @@ -431,12 +431,13 @@ def adjust_hue(img, hue_factor): Returns: PIL.Image: Hue adjusted image. """ - if not(hue_factor <= 0.5 and hue_factor >= -0.5): - raise ValueError('hue_factor {} out of range.'.format(hue_factor)) + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + input_mode = img.mode h, s, v = img.convert('HSV').split() np_h = np.array(h, dtype='uint8') @@ -445,15 +446,15 @@ def adjust_hue(img, hue_factor): np_h += np.uint8(hue_factor * 255) h = Image.fromarray(np_h, 'L') - img = Image.merge('HSV', (h, s, v)).convert('RGB') + img = Image.merge('HSV', (h, s, v)).convert(input_mode) return img def adjust_gamma(img, gamma, gain=1): """Perform gamma correction on an image. - Also known as Power Law Transform. Intensities are adjusted based on the - following equation: + Also known as Power Law Transform. 
Intensities in RGB mode are adjusted + based on the following equation: I_out = 255 * gain * ((I_in / 255) ** gamma) @@ -472,13 +473,14 @@ def adjust_gamma(img, gamma, gain=1): if gamma < 0: raise ValueError('Gamma should be a non-negative real number') + input_mode = img.mode img = img.convert('RGB') np_img = np.array(img, dtype='float32') np_img = 255 * gain * ((np_img / 255) ** gamma) np_img = np.uint8(np.clip(np_img, 0, 255)) - img = Image.fromarray(np_img, 'RGB') + img = Image.fromarray(np_img, 'RGB').convert(input_mode) return img From d670c0734d5f3769b462bab2374c932bbb7a8b0d Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Mon, 2 Oct 2017 23:50:44 +0530 Subject: [PATCH 08/10] Fallback to adjust_hue in case of BW images --- torchvision/transforms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 6814c637a04..862661e9e74 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -438,6 +438,9 @@ def adjust_hue(img, hue_factor): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + h, s, v = img.convert('HSV').split() np_h = np.array(h, dtype='uint8') From b08d89d867b04a44cdc768d1c208ceb1dd3ccd87 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Mon, 2 Oct 2017 23:50:55 +0530 Subject: [PATCH 09/10] Add tests --- test/test_transforms.py | 161 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index a53f14d4d92..5724d0a1b3f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -422,6 +422,167 @@ def test_random_horizontal_flip(self): p_value = stats.binom_test(num_horizontal, 100, p=0.5) assert p_value > 0.05 + def test_adjust_brightness(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_brightness(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_brightness(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_brightness(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_contrast(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_contrast(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_contrast(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_contrast(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_saturation(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 
135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_saturation(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_saturation(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_saturation(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_hue(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + with self.assertRaises(ValueError): + transforms.adjust_hue(x_pil, -0.7) + transforms.adjust_hue(x_pil, 1) + + # test 0: almost same as x_data but not exact. + # probably because hsv <-> rgb floating point ops + y_pil = transforms.adjust_hue(x_pil, 0) + y_np = np.array(y_pil) + y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 1 + y_pil = transforms.adjust_hue(x_pil, 0.25) + y_np = np.array(y_pil) + y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_hue(x_pil, -0.25) + y_np = np.array(y_pil) + y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjust_gamma(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transforms.adjust_gamma(x_pil, 1) + y_np = np.array(y_pil) + assert np.allclose(y_np, x_np) + + # test 1 + y_pil = transforms.adjust_gamma(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + # test 2 + y_pil = transforms.adjust_gamma(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + assert np.allclose(y_np, y_ans) + + def test_adjusts_L_mode(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_rgb = Image.fromarray(x_np, mode='RGB') + + x_l = x_rgb.convert('L') + assert transforms.adjust_brightness(x_l, 2).mode == 'L' + assert transforms.adjust_saturation(x_l, 2).mode == 'L' + assert transforms.adjust_contrast(x_l, 2).mode == 'L' + assert transforms.adjust_hue(x_l, 0.4).mode == 'L' + assert transforms.adjust_gamma(x_l, 0.5).mode == 'L' + + def test_color_jitter(self): + color_jitter = transforms.ColorJitter(2, 2, 2, 0.1) + + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_pil_2 = x_pil.convert('L') + + for i in range(10): + y_pil = color_jitter(x_pil) + 
assert y_pil.mode == x_pil.mode + + y_pil_2 = color_jitter(x_pil_2) + assert y_pil_2.mode == x_pil_2.mode + if __name__ == '__main__': unittest.main() From f04be260e8999c38d73c23261e910f7718be0604 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Wed, 4 Oct 2017 23:58:46 +0530 Subject: [PATCH 10/10] fix dtype --- torchvision/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 862661e9e74..15fec0c9c65 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -443,7 +443,7 @@ def adjust_hue(img, hue_factor): h, s, v = img.convert('HSV').split() - np_h = np.array(h, dtype='uint8') + np_h = np.array(h, dtype=np.uint8) # uint8 addition take cares of rotation across boundaries with np.errstate(over='ignore'): np_h += np.uint8(hue_factor * 255) @@ -479,7 +479,7 @@ def adjust_gamma(img, gamma, gain=1): input_mode = img.mode img = img.convert('RGB') - np_img = np.array(img, dtype='float32') + np_img = np.array(img, dtype=np.float32) np_img = 255 * gain * ((np_img / 255) ** gamma) np_img = np.uint8(np.clip(np_img, 0, 255))
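The gamma curve introduced in this series is easy to check by hand: with gain = 1, I_out = 255 * (I_in / 255) ** gamma, clipped to [0, 255] and then truncated by the uint8 cast. A small sketch in plain numpy reproducing two of the expected values for the pixel 135 from the test image; the helper name gamma_curve is illustrative:

import numpy as np


def gamma_curve(intensity, gamma, gain=1):
    # Same arithmetic as adjust_gamma, applied to a single intensity:
    # clip to [0, 255], then truncate (not round) via the uint8 cast.
    out = 255.0 * gain * (intensity / 255.0) ** gamma
    return int(np.uint8(np.clip(out, 0, 255)))


print(gamma_curve(135, 0.5))   # 185 -- gamma < 1 brightens mid-tones, matching the test
print(gamma_curve(135, 2))     # 71  -- gamma > 1 darkens them, matching the test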
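Taken together, the tests exercise the new transforms directly on PIL images; in application code ColorJitter behaves like any other transform object. A usage sketch, assuming this patch series is applied so that ColorJitter and the adjust_* functions are exposed from torchvision.transforms; the file name and jitter ranges are illustrative:

from PIL import Image
from torchvision import transforms

# Each call resamples brightness/contrast/saturation factors from
# [max(0, 1 - x), 1 + x] and the hue shift from [-0.1, 0.1], then applies
# the adjustments in a random order.
jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                saturation=0.4, hue=0.1)

img = Image.open('example.jpg').convert('RGB')   # placeholder file name
augmented = jitter(img)

# The functional forms can also be called directly with fixed factors:
brighter = transforms.adjust_brightness(img, 1.5)
muted = transforms.adjust_saturation(img, 0.5)
hue_shifted = transforms.adjust_hue(img, 0.1)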