@@ -728,20 +728,44 @@ class ColorJitter(object):
728
728
"""Randomly change the brightness, contrast and saturation of an image.
729
729
730
730
Args:
731
- brightness (float): How much to jitter brightness. brightness_factor
732
- is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
733
- contrast (float): How much to jitter contrast. contrast_factor
734
- is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
735
- saturation (float): How much to jitter saturation. saturation_factor
736
- is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
737
- hue(float): How much to jitter hue. hue_factor is chosen uniformly from
738
- [-hue, hue]. Should be >=0 and <= 0.5.
731
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
732
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
733
+ or the given [min, max]. Should be non negative numbers.
734
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
735
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
736
+ or the given [min, max]. Should be non negative numbers.
737
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
738
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
739
+ or the given [min, max]. Should be non negative numbers.
740
+ hue (float or tuple of float (min, max)): How much to jitter hue.
741
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
742
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
739
743
"""
740
744
def __init__ (self , brightness = 0 , contrast = 0 , saturation = 0 , hue = 0 ):
741
- self .brightness = brightness
742
- self .contrast = contrast
743
- self .saturation = saturation
744
- self .hue = hue
745
+ self .brightness = self ._check_input (brightness , 'brightness' )
746
+ self .contrast = self ._check_input (contrast , 'contrast' )
747
+ self .saturation = self ._check_input (saturation , 'saturation' )
748
+ self .hue = self ._check_input (hue , 'hue' , center = 0 , bound = (- 0.5 , 0.5 ),
749
+ clip_first_on_zero = False )
750
+
751
+ def _check_input (self , value , name , center = 1 , bound = (0 , float ('inf' )), clip_first_on_zero = True ):
752
+ if isinstance (value , numbers .Number ):
753
+ if value < 0 :
754
+ raise ValueError ("If {} is a single number, it must be non negative." .format (name ))
755
+ value = [center - value , center + value ]
756
+ if clip_first_on_zero :
757
+ value [0 ] = max (value [0 ], 0 )
758
+ elif isinstance (value , (tuple , list )) and len (value ) == 2 :
759
+ if not bound [0 ] <= value [0 ] <= value [1 ] <= bound [1 ]:
760
+ raise ValueError ("{} values should be between {}" .format (name , bound ))
761
+ else :
762
+ raise TypeError ("{} should be a single number or a list/tuple with lenght 2." .format (name ))
763
+
764
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
765
+ # or (0., 0.) for hue, do nothing
766
+ if value [0 ] == value [1 ] == center :
767
+ value = None
768
+ return value
745
769
746
770
@staticmethod
747
771
def get_params (brightness , contrast , saturation , hue ):
@@ -754,20 +778,21 @@ def get_params(brightness, contrast, saturation, hue):
754
778
saturation in a random order.
755
779
"""
756
780
transforms = []
757
- if brightness > 0 :
758
- brightness_factor = random .uniform (max (0 , 1 - brightness ), 1 + brightness )
781
+
782
+ if brightness is not None :
783
+ brightness_factor = random .uniform (brightness [0 ], brightness [1 ])
759
784
transforms .append (Lambda (lambda img : F .adjust_brightness (img , brightness_factor )))
760
785
761
- if contrast > 0 :
762
- contrast_factor = random .uniform (max ( 0 , 1 - contrast ), 1 + contrast )
786
+ if contrast is not None :
787
+ contrast_factor = random .uniform (contrast [ 0 ], contrast [ 1 ] )
763
788
transforms .append (Lambda (lambda img : F .adjust_contrast (img , contrast_factor )))
764
789
765
- if saturation > 0 :
766
- saturation_factor = random .uniform (max ( 0 , 1 - saturation ), 1 + saturation )
790
+ if saturation is not None :
791
+ saturation_factor = random .uniform (saturation [ 0 ], saturation [ 1 ] )
767
792
transforms .append (Lambda (lambda img : F .adjust_saturation (img , saturation_factor )))
768
793
769
- if hue > 0 :
770
- hue_factor = random .uniform (- hue , hue )
794
+ if hue is not None :
795
+ hue_factor = random .uniform (hue [ 0 ] , hue [ 1 ] )
771
796
transforms .append (Lambda (lambda img : F .adjust_hue (img , hue_factor )))
772
797
773
798
random .shuffle (transforms )
0 commit comments