Skip to content

Make ColorJitter torchscriptable #2298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ def test_adjustments(self):
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
self.assertTrue(torch.equal(img, img_clone))

# test for class interface
f = transforms.ColorJitter(brightness=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(contrast=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(saturation=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(brightness=1)
scripted_fn = torch.jit.script(f)
scripted_fn(img)

def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
Expand Down
26 changes: 26 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,32 @@ def test_random_horizontal_flip(self):
def test_random_vertical_flip(self):
self._test_flip('vflip', 'RandomVerticalFlip')

def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
for _ in range(20):
factor = 3 * torch.rand(1).item()
tensor, _ = self._create_data()
pil_img = T.ToPILImage()(tensor)

for func in fns:
adjusted_tensor = getattr(F, func)(tensor, factor)
adjusted_pil_img = getattr(F, func)(pil_img, factor)

adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img)
scripted_fn = torch.jit.script(getattr(F, func))
adjusted_tensor_script = scripted_fn(tensor, factor)

if not tensor.dtype.is_floating_point:
adjusted_tensor = adjusted_tensor.to(torch.float) / 255
adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255

# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max()
max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)


if __name__ == '__main__':
unittest.main()
66 changes: 22 additions & 44 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,67 +633,61 @@ def ten_crop(img, size, vertical_flip=False):
return first_five + second_five


def adjust_brightness(img, brightness_factor):
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
"""Adjust brightness of an Image.

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Torch Tensor): Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.

Returns:
PIL Image: Brightness adjusted image.
PIL Image or Torch Tensor: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.adjust_brightness(img, brightness_factor)

enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
return F_t.adjust_brightness(img, brightness_factor)


def adjust_contrast(img, contrast_factor):
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
"""Adjust contrast of an Image.

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Torch Tensor): Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.

Returns:
PIL Image: Contrast adjusted image.
PIL Image or Torch Tensor: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.adjust_contrast(img, contrast_factor)

enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
return F_t.adjust_contrast(img, contrast_factor)


def adjust_saturation(img, saturation_factor):
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
"""Adjust color saturation of an image.

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Torch Tensor): Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.

Returns:
PIL Image: Saturation adjusted image.
PIL Image or Torch Tensor: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(img, torch.Tensor):
return F_pil.adjust_saturation(img, saturation_factor)

enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img, hue_factor):
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""Adjust hue of an image.

The image hue is adjusted by converting the image to HSV and
Expand All @@ -718,26 +712,10 @@ def adjust_hue(img, hue_factor):
Returns:
PIL Image: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img

h, s, v = img.convert('HSV').split()

np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor)

img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def adjust_gamma(img, gamma, gain=1):
Expand Down
108 changes: 108 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
except ImportError:
accimage = None
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
import numpy as np


@torch.jit.unused
Expand Down Expand Up @@ -44,3 +45,110 @@ def vflip(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.transpose(Image.FLIP_TOP_BOTTOM)


@torch.jit.unused
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an RGB image.

Args:
img (PIL Image): Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.

Returns:
PIL Image: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img


@torch.jit.unused
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
PIL Image: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img


@torch.jit.unused
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
PIL Image: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img


@torch.jit.unused
def adjust_hue(img, hue_factor):
"""Adjust hue of an image.

The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.

`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.

See `Hue`_ for more details.

.. _Hue: https://en.wikipedia.org/wiki/Hue

Args:
img (PIL Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.

Returns:
PIL Image: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img

h, s, v = img.convert('HSV').split()

np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')

img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
42 changes: 33 additions & 9 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def __repr__(self):
return format_string


class ColorJitter(object):
class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast and saturation of an image.

Args:
Expand All @@ -882,20 +882,23 @@ class ColorJitter(object):
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
"""

def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False)

@torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError("If {} is a single number, it must be non negative.".format(name))
value = [center - value, center + value]
value = [center - float(value), center + float(value)]
if clip_first_on_zero:
value[0] = max(value[0], 0)
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError("{} values should be between {}".format(name, bound))
Expand All @@ -909,6 +912,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
return value

@staticmethod
@torch.jit.unused
def get_params(brightness, contrast, saturation, hue):
Copy link
Collaborator

@vfdev-5 vfdev-5 Aug 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmassa @clemkoa looks like get_params becomes unused, why did we keep it ?

Copy link
Member

@fmassa fmassa Aug 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For backwards-compatibility. The way it was implemented was pretty non-canonical though, but I would be ok if we made a BC-breaking change to use it again by maybe returning a list of transform name / params somehow, but this seems lower priority for me I think

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can give a shot to that as I'm working on F.adjust_hue and it is used by ColorJitter...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me, but can you do it in a separate PR which only does this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problems, can do it in a separate PR

"""Get a randomized transform to be applied on image.

Expand Down Expand Up @@ -941,17 +945,37 @@ def get_params(brightness, contrast, saturation, hue):

return transform

def __call__(self, img):
def forward(self, img):
"""
Args:
img (PIL Image): Input image.
img (PIL Image or Tensor): Input image.

Returns:
PIL Image: Color jittered image.
PIL Image or Tensor: Color jittered image.
"""
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)
return transform(img)
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and self.brightness is not None:
brightness = self.brightness
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = F.adjust_brightness(img, brightness_factor)

if fn_id == 1 and self.contrast is not None:
contrast = self.contrast
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = F.adjust_contrast(img, contrast_factor)

if fn_id == 2 and self.saturation is not None:
saturation = self.saturation
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = F.adjust_saturation(img, saturation_factor)

if fn_id == 3 and self.hue is not None:
hue = self.hue
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = F.adjust_hue(img, hue_factor)

return img

def __repr__(self):
format_string = self.__class__.__name__ + '('
Expand Down