Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Color transforms #275

Merged
merged 10 commits into from
Oct 4, 2017
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,167 @@ def test_random_horizontal_flip(self):
p_value = stats.binom_test(num_horizontal, 100, p=0.5)
assert p_value > 0.05

def test_adjust_brightness(self):
    """Check adjust_brightness against fixed pixel values for factors 1, 0.5 and 2."""
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (factor, expected flat pixel values); a factor of 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]),
        (2, [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]),
    ]
    for factor, expected in cases:
        result_np = np.array(transforms.adjust_brightness(src_pil, factor))
        expected_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(result_np, expected_np)

def test_adjust_contrast(self):
    """Check adjust_contrast against fixed pixel values for factors 1, 0.5 and 2."""
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (factor, expected flat pixel values); a factor of 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]),
        (2, [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]),
    ]
    for factor, expected in cases:
        result_np = np.array(transforms.adjust_contrast(src_pil, factor))
        expected_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(result_np, expected_np)

def test_adjust_saturation(self):
    """Check adjust_saturation against fixed pixel values for factors 1, 0.5 and 2."""
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (factor, expected flat pixel values); a factor of 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]),
        (2, [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]),
    ]
    for factor, expected in cases:
        result_np = np.array(transforms.adjust_saturation(src_pil, factor))
        expected_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(result_np, expected_np)

def test_adjust_hue(self):
    """Check adjust_hue range validation and its output for shifts 0, 0.25 and -0.25."""
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode='RGB')

    # Each out-of-range factor gets its own assertRaises block: the context
    # manager exits at the first exception, so a second call placed in the
    # same block would never run and would go unchecked.
    with self.assertRaises(ValueError):
        transforms.adjust_hue(x_pil, -0.7)
    with self.assertRaises(ValueError):
        transforms.adjust_hue(x_pil, 1)

    # test 0: almost same as x_data but not exact.
    # probably because hsv <-> rgb floating point ops
    y_pil = transforms.adjust_hue(x_pil, 0)
    y_np = np.array(y_pil)
    y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    assert np.allclose(y_np, y_ans)

    # test 1: positive quarter-turn shift of the hue channel
    y_pil = transforms.adjust_hue(x_pil, 0.25)
    y_np = np.array(y_pil)
    y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    assert np.allclose(y_np, y_ans)

    # test 2: negative quarter-turn shift of the hue channel
    y_pil = transforms.adjust_hue(x_pil, -0.25)
    y_np = np.array(y_pil)
    y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    assert np.allclose(y_np, y_ans)

def test_adjust_gamma(self):
    """Check adjust_gamma against fixed pixel values for gamma 1, 0.5 and 2."""
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (gamma, expected flat pixel values); gamma 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]),
        (2, [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]),
    ]
    for gamma, expected in cases:
        result_np = np.array(transforms.adjust_gamma(src_pil, gamma))
        expected_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(result_np, expected_np)

def test_adjusts_L_mode(self):
    """Every adjust_* function must hand back an 'L' image when given one."""
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    rgb_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    gray = Image.fromarray(rgb_np, mode='RGB').convert('L')

    # (adjust function, factor) pairs, exercised in this order.
    checks = (
        (transforms.adjust_brightness, 2),
        (transforms.adjust_saturation, 2),
        (transforms.adjust_contrast, 2),
        (transforms.adjust_hue, 0.4),
        (transforms.adjust_gamma, 0.5),
    )
    for adjust, factor in checks:
        assert adjust(gray, factor).mode == 'L'

def test_color_jitter(self):
    """ColorJitter must preserve the image mode for both RGB and L inputs."""
    jitter = transforms.ColorJitter(2, 2, 2, 0.1)

    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    rgb_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    rgb_img = Image.fromarray(rgb_np, mode='RGB')
    gray_img = rgb_img.convert('L')

    # The transform is random, so repeat it a few times on each input.
    for _ in range(10):
        jittered_rgb = jitter(rgb_img)
        assert jittered_rgb.mode == rgb_img.mode

        jittered_gray = jitter(gray_img)
        assert jittered_gray.mode == gray_img.mode


# Run the unittest test runner only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
205 changes: 204 additions & 1 deletion torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import math
import random
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageEnhance
try:
import accimage
except ImportError:
Expand Down Expand Up @@ -348,6 +348,145 @@ def ten_crop(img, size, vertical_flip=False):
return first_five + second_five


def adjust_brightness(img, brightness_factor):
    """Scale the brightness of a PIL Image.

    Args:
        img (PIL.Image): Image whose brightness is changed.
        brightness_factor (float): Non-negative multiplier. 0 yields a black
            image, 1 yields the original image and 2 doubles the brightness.

    Returns:
        PIL.Image: Brightness adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Brightness(img).enhance(brightness_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def adjust_contrast(img, contrast_factor):
    """Scale the contrast of a PIL Image.

    Args:
        img (PIL.Image): Image whose contrast is changed.
        contrast_factor (float): Non-negative multiplier. 0 yields a solid
            gray image, 1 yields the original image and 2 doubles the
            contrast.

    Returns:
        PIL.Image: Contrast adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Contrast(img).enhance(contrast_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def adjust_saturation(img, saturation_factor):
    """Scale the color saturation of a PIL Image.

    Args:
        img (PIL.Image): Image whose saturation is changed.
        saturation_factor (float): Multiplier for the saturation. 0 yields a
            black and white image, 1 yields the original image and 2 doubles
            the saturation.

    Returns:
        PIL.Image: Saturation adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Color(img).enhance(saturation_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL.Image: Hue adjusted image. Images in mode 'L', '1', 'I' or 'F'
        are returned unchanged.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` is not a PIL Image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        # Include the offending value so the caller can see what was passed.
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    # Single-channel modes are handed back untouched.
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype='uint8')

    # Wrap the shift into 0..255 with Python integers first, so a negative
    # factor is never cast from a negative float straight to uint8 (that cast
    # is not guaranteed to wrap around on every platform / NumPy version).
    shift = np.uint8(int(hue_factor * 255) % 256)

    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += shift
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img


def adjust_gamma(img, gamma, gain=1):
    """Apply gamma correction (power law transform) to an image.

    The image is converted to RGB, each intensity is mapped through

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    clipped to [0, 255], and the result is converted back to the mode of the
    input image.

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        gamma (float): Non negative real number. gamma larger than 1 makes
            the shadows darker, while gamma smaller than 1 makes dark regions
            lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL.Image: Gamma corrected image in the same mode as ``img``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
        ValueError: If ``gamma`` is negative.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    original_mode = img.mode

    # Work in float32 RGB so the power law is applied per channel.
    rgb_pixels = np.array(img.convert('RGB'), dtype='float32')
    normalized = rgb_pixels / 255
    corrected = 255 * gain * (normalized ** gamma)

    # Clip before the uint8 cast so out-of-range values saturate.
    corrected = np.uint8(np.clip(corrected, 0, 255))

    return Image.fromarray(corrected, 'RGB').convert(original_mode)


class Compose(object):
"""Composes several transforms together.

Expand Down Expand Up @@ -756,3 +895,67 @@ def __init__(self, size, vertical_flip=False):

def __call__(self, img):
return ten_crop(img, self.size, self.vertical_flip)


class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness, self.contrast = brightness, contrast
        self.saturation, self.hue = saturation, hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Build one randomized transform to apply to an image.

        Arguments are the same as those of __init__. A component whose
        amount is not greater than 0 is left out of the transform.

        Returns:
            Compose of the selected adjustments, each with a freshly sampled
            factor, applied in a random order.
        """
        def scale_range(amount):
            # Multiplicative factors are centred on 1 and floored at 0.
            return max(0, 1 - amount), 1 + amount

        def shift_range(amount):
            # The hue shift is symmetric around 0.
            return -amount, amount

        # Listed in the order the factors are sampled.
        candidates = (
            (brightness, scale_range, adjust_brightness),
            (contrast, scale_range, adjust_contrast),
            (saturation, scale_range, adjust_saturation),
            (hue, shift_range, adjust_hue),
        )

        steps = []
        for amount, bounds, adjust in candidates:
            if amount > 0:
                low, high = bounds(amount)
                factor = np.random.uniform(low, high)
                # Bind adjust/factor as defaults so each lambda keeps its own
                # values rather than the loop's last ones.
                steps.append(Lambda(lambda img, adjust=adjust, factor=factor: adjust(img, factor)))

        np.random.shuffle(steps)
        return Compose(steps)

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Input image.

        Returns:
            PIL.Image: Color jittered image.
        """
        jitter = self.get_params(self.brightness, self.contrast,
                                 self.saturation, self.hue)
        return jitter(img)