Skip to content

Commit b51a2c3

Browse files
yaox12fmassa
authored andcommitted
ColorJitter Enhancement (#548)
* ColorJitter Enhancement * reduce redundancy * improve functional * check input in init * fix ci fail
1 parent d6c7900 commit b51a2c3

File tree

1 file changed

+45
-20
lines changed

1 file changed

+45
-20
lines changed

torchvision/transforms/transforms.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -728,20 +728,44 @@ class ColorJitter(object):
728728
"""Randomly change the brightness, contrast and saturation of an image.
729729
730730
Args:
731-
brightness (float): How much to jitter brightness. brightness_factor
732-
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
733-
contrast (float): How much to jitter contrast. contrast_factor
734-
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
735-
saturation (float): How much to jitter saturation. saturation_factor
736-
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
737-
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
738-
[-hue, hue]. Should be >=0 and <= 0.5.
731+
brightness (float or tuple of float (min, max)): How much to jitter brightness.
732+
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
733+
or the given [min, max]. Should be non negative numbers.
734+
contrast (float or tuple of float (min, max)): How much to jitter contrast.
735+
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
736+
or the given [min, max]. Should be non negative numbers.
737+
saturation (float or tuple of float (min, max)): How much to jitter saturation.
738+
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
739+
or the given [min, max]. Should be non negative numbers.
740+
hue (float or tuple of float (min, max)): How much to jitter hue.
741+
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
742+
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
739743
"""
740744
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
741-
self.brightness = brightness
742-
self.contrast = contrast
743-
self.saturation = saturation
744-
self.hue = hue
745+
self.brightness = self._check_input(brightness, 'brightness')
746+
self.contrast = self._check_input(contrast, 'contrast')
747+
self.saturation = self._check_input(saturation, 'saturation')
748+
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
749+
clip_first_on_zero=False)
750+
751+
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
752+
if isinstance(value, numbers.Number):
753+
if value < 0:
754+
raise ValueError("If {} is a single number, it must be non negative.".format(name))
755+
value = [center - value, center + value]
756+
if clip_first_on_zero:
757+
value[0] = max(value[0], 0)
758+
elif isinstance(value, (tuple, list)) and len(value) == 2:
759+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
760+
raise ValueError("{} values should be between {}".format(name, bound))
761+
else:
762+
raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
763+
764+
# if value is 0 or (1., 1.) for brightness/contrast/saturation
765+
# or (0., 0.) for hue, do nothing
766+
if value[0] == value[1] == center:
767+
value = None
768+
return value
745769

746770
@staticmethod
747771
def get_params(brightness, contrast, saturation, hue):
@@ -754,20 +778,21 @@ def get_params(brightness, contrast, saturation, hue):
754778
saturation in a random order.
755779
"""
756780
transforms = []
757-
if brightness > 0:
758-
brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
781+
782+
if brightness is not None:
783+
brightness_factor = random.uniform(brightness[0], brightness[1])
759784
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
760785

761-
if contrast > 0:
762-
contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
786+
if contrast is not None:
787+
contrast_factor = random.uniform(contrast[0], contrast[1])
763788
transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
764789

765-
if saturation > 0:
766-
saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
790+
if saturation is not None:
791+
saturation_factor = random.uniform(saturation[0], saturation[1])
767792
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
768793

769-
if hue > 0:
770-
hue_factor = random.uniform(-hue, hue)
794+
if hue is not None:
795+
hue_factor = random.uniform(hue[0], hue[1])
771796
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
772797

773798
random.shuffle(transforms)

0 commit comments

Comments
 (0)