@@ -143,7 +143,104 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe
         return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)


-adjust_hue_image_tensor = _FT.adjust_hue
+def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
+    r, g, _ = image.unbind(dim=-3)
+
+    # Implementation is based on
+    # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330
+    minc, maxc = torch.aminmax(image, dim=-3)
+
+    # The algorithm erases the S and H channels where `maxc = minc`. This prevents NaN
+    # from appearing in the results, because
+    # + the S channel has a division by `maxc`, which is zero only if `maxc = minc`
+    # + the H channel has a division by `(maxc - minc)`.
+    #
+    # Instead of overwriting NaN afterwards, we prevent it from occurring in the first
+    # place, so we never have to deal with a NaN saved in a buffer for backprop, should
+    # that ever be supported; either way, it doesn't hurt to do so.
+    eqc = maxc == minc
+
+    channels_range = maxc - minc
+    # Since `eqc => channels_range = 0`, replacing the denominator with 1 where `eqc` holds is fine.
+    ones = torch.ones_like(maxc)
+    s = channels_range / torch.where(eqc, ones, maxc)
+    # Note that `eqc => maxc = minc = r = g = b`. In that case the calculation of `h`
+    # below would reduce to `bc - gc + 2 + rc - bc + 4 + gc - rc = 6`, so it would not
+    # matter what values `rc`, `gc`, and `bc` have here, and thus replacing the
+    # denominator with 1 where `eqc` holds is fine.
+    channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3)
+    rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3)
+
+    mask_maxc_neq_r = maxc != r
+    mask_maxc_eq_g = maxc == g
+    mask_maxc_neq_g = ~mask_maxc_eq_g
+
+    hr = (bc - gc).mul_(~mask_maxc_neq_r)
+    hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
+    hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
+
+    h = hr.add_(hg).add_(hb)
+    h = h.div_(6.0).add_(1.0).fmod_(1.0)
+    return torch.stack((h, s, maxc), dim=-3)
+
+
+def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
+    h, s, v = img.unbind(dim=-3)
+    h6 = h * 6
+    i = torch.floor(h6)
+    f = h6 - i
+    i = i.to(dtype=torch.int32)
+
+    p = (v * (1.0 - s)).clamp_(0.0, 1.0)
+    q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
+    t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
+    i.remainder_(6)
+
+    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
+
+    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
+    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
+    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
+    a4 = torch.stack((a1, a2, a3), dim=-4)
+
+    return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
+
+
+def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
+    if not (-0.5 <= hue_factor <= 0.5):
+        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
+
+    if not (isinstance(image, torch.Tensor)):
+        raise TypeError("Input img should be Tensor image")
+
+    c = get_num_channels_image_tensor(image)
+
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+
+    if c == 1:  # Match PIL behaviour
+        return image
+
+    if image.numel() == 0:
+        # exit earlier on empty images
+        return image
+
+    orig_dtype = image.dtype
+    if image.dtype == torch.uint8:
+        image = image / 255.0
+
+    image = _rgb_to_hsv(image)
+    h, s, v = image.unbind(dim=-3)
+    h.add_(hue_factor).remainder_(1.0)
+    image = torch.stack((h, s, v), dim=-3)
+    image_hue_adj = _hsv_to_rgb(image)
+
+    if orig_dtype == torch.uint8:
+        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
+
+    return image_hue_adj
+
+
 adjust_hue_image_pil = _FP.adjust_hue


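As a review aid, here is a minimal sanity-check sketch for the new pure-tensor hue path. It is not part of the commit; it assumes `_rgb_to_hsv`, `_hsv_to_rgb`, and `adjust_hue_image_tensor` from the hunk above, plus the module helper `get_num_channels_image_tensor`, are in scope (pasted into a scratch script or imported from the module this diff touches; the exact import path is not shown here).

```python
import torch

# Sketch only: _rgb_to_hsv, _hsv_to_rgb, adjust_hue_image_tensor and
# get_num_channels_image_tensor from the diff are assumed to be in scope.

torch.manual_seed(0)
image = torch.rand(3, 32, 32)  # float RGB image with values in [0, 1]

# RGB -> HSV -> RGB should reproduce the input up to floating-point error.
round_trip = _hsv_to_rgb(_rgb_to_hsv(image))
assert torch.allclose(round_trip, image, atol=1e-5)

# A hue shift of 0 should leave the image (nearly) unchanged.
assert torch.allclose(adjust_hue_image_tensor(image, hue_factor=0.0), image, atol=1e-5)

# Degenerate pixels (maxc == minc, i.e. gray) must not produce NaN; this is
# exactly the case the comments in _rgb_to_hsv guard against.
gray = torch.full((3, 8, 8), 0.5)
assert not torch.isnan(adjust_hue_image_tensor(gray, hue_factor=0.25)).any()
```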