@@ -118,6 +118,54 @@ def adjust_contrast(img, contrast_factor):
118
118
return _blend (img , mean , contrast_factor )
119
119
120
120
121
+ def adjust_hue (img , hue_factor ):
122
+ """Adjust hue of an image.
123
+
124
+ The image hue is adjusted by converting the image to HSV and
125
+ cyclically shifting the intensities in the hue channel (H).
126
+ The image is then converted back to original image mode.
127
+
128
+ `hue_factor` is the amount of shift in H channel and must be in the
129
+ interval `[-0.5, 0.5]`.
130
+
131
+ See `Hue`_ for more details.
132
+
133
+ .. _Hue: https://en.wikipedia.org/wiki/Hue
134
+
135
+ Args:
136
+ img (Tensor): Image to be adjusted. Image type is either uint8 or float.
137
+ hue_factor (float): How much to shift the hue channel. Should be in
138
+ [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
139
+ HSV space in positive and negative direction respectively.
140
+ 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
141
+ with complementary colors while 0 gives the original image.
142
+
143
+ Returns:
144
+ Tensor: Hue adjusted image.
145
+ """
146
+ if not (- 0.5 <= hue_factor <= 0.5 ):
147
+ raise ValueError ('hue_factor ({}) is not in [-0.5, 0.5].' .format (hue_factor ))
148
+
149
+ if not _is_tensor_a_torch_image (img ):
150
+ raise TypeError ('tensor is not a torch image.' )
151
+
152
+ orig_dtype = img .dtype
153
+ if img .dtype == torch .uint8 :
154
+ img = img .to (dtype = torch .float32 ) / 255.0
155
+
156
+ img = _rgb2hsv (img )
157
+ h , s , v = img .unbind (0 )
158
+ h += hue_factor
159
+ h = h % 1.0
160
+ img = torch .stack ((h , s , v ))
161
+ img_hue_adj = _hsv2rgb (img )
162
+
163
+ if orig_dtype == torch .uint8 :
164
+ img_hue_adj = (img_hue_adj * 255.0 ).to (dtype = orig_dtype )
165
+
166
+ return img_hue_adj
167
+
168
+
121
169
def adjust_saturation (img , saturation_factor ):
122
170
# type: (Tensor, float) -> Tensor
123
171
"""Adjust color saturation of an RGB image.
@@ -235,3 +283,47 @@ def _blend(img1, img2, ratio):
235
283
# type: (Tensor, Tensor, float) -> Tensor
236
284
bound = 1 if img1 .dtype in [torch .half , torch .float32 , torch .float64 ] else 255
237
285
return (ratio * img1 + (1 - ratio ) * img2 ).clamp (0 , bound ).to (img1 .dtype )
286
+
287
+
288
+ def _rgb2hsv (img ):
289
+ r , g , b = img .unbind (0 )
290
+
291
+ maxc , _ = torch .max (img , dim = 0 )
292
+ minc , _ = torch .min (img , dim = 0 )
293
+
294
+ cr = maxc - minc
295
+ s = cr / maxc
296
+ rc = (maxc - r ) / cr
297
+ gc = (maxc - g ) / cr
298
+ bc = (maxc - b ) / cr
299
+
300
+ t = (maxc != minc )
301
+ s = t * s
302
+ hr = (maxc == r ) * (bc - gc )
303
+ hg = ((maxc == g ) & (maxc != r )) * (2.0 + rc - bc )
304
+ hb = ((maxc != g ) & (maxc != r )) * (4.0 + gc - rc )
305
+ h = (hr + hg + hb )
306
+ h = t * h
307
+ h = torch .fmod ((h / 6.0 + 1.0 ), 1.0 )
308
+ return torch .stack ((h , s , maxc ))
309
+
310
+
311
+ def _hsv2rgb (img ):
312
+ h , s , v = img .unbind (0 )
313
+ i = torch .floor (h * 6.0 )
314
+ f = (h * 6.0 ) - i
315
+ i = i .to (dtype = torch .int32 )
316
+
317
+ p = torch .clamp ((v * (1.0 - s )), 0.0 , 1.0 )
318
+ q = torch .clamp ((v * (1.0 - s * f )), 0.0 , 1.0 )
319
+ t = torch .clamp ((v * (1.0 - s * (1.0 - f ))), 0.0 , 1.0 )
320
+ i = i % 6
321
+
322
+ mask = i == torch .arange (6 )[:, None , None ]
323
+
324
+ a1 = torch .stack ((v , q , p , p , t , v ))
325
+ a2 = torch .stack ((t , v , v , q , p , p ))
326
+ a3 = torch .stack ((p , p , t , v , v , q ))
327
+ a4 = torch .stack ((a1 , a2 , a3 ))
328
+
329
+ return torch .einsum ("ijk, xijk -> xjk" , mask .to (dtype = img .dtype ), a4 )
0 commit comments