@@ -2,13 +2,13 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
-from ._meta import _rgb_to_gray, convert_dtype_image_tensor
+from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
 
 
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image1.dtype)
     output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
     return output if fp else output.to(image1.dtype)
 
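The change above (and the matching one in `adjust_brightness_image_tensor`, `solarize_image_tensor`, and `autocontrast_image_tensor` further down) replaces the hard-coded `1.0` / `255.0` bound with `_FT._max_value(image.dtype)`, so the clamp is correct for integer dtypes other than `torch.uint8`. A minimal sketch of what such a dtype-aware bound boils down to; `max_value` here is an illustrative stand-in, not the actual torchvision helper:

```python
import torch

def max_value(dtype: torch.dtype) -> float:
    # Floating point images are assumed to live in [0.0, 1.0]; integer images
    # use the full positive range of their dtype.
    if dtype.is_floating_point:
        return 1.0
    return float(torch.iinfo(dtype).max)

for dtype in (torch.uint8, torch.int16, torch.float32):
    print(dtype, max_value(dtype))  # 255.0, 32767.0, 1.0
```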
@@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
     _FT._assert_channels(image, [1, 3])
 
     fp = image.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image.dtype)
     output = image.mul(brightness_factor).clamp_(0, bound)
     return output if fp else output.to(image.dtype)
 
@@ -222,19 +222,15 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
         return image
 
     orig_dtype = image.dtype
-    if image.dtype == torch.uint8:
-        image = image / 255.0
+    image = convert_dtype_image_tensor(image, torch.float32)
 
     image = _rgb_to_hsv(image)
     h, s, v = image.unbind(dim=-3)
     h.add_(hue_factor).remainder_(1.0)
     image = torch.stack((h, s, v), dim=-3)
     image_hue_adj = _hsv_to_rgb(image)
 
-    if orig_dtype == torch.uint8:
-        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
-
-    return image_hue_adj
+    return convert_dtype_image_tensor(image_hue_adj, orig_dtype)
 
 
 adjust_hue_image_pil = _FP.adjust_hue
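With `convert_dtype_image_tensor`, `adjust_hue_image_tensor` no longer special-cases `torch.uint8`: any input is converted to `float32` for the HSV math and converted back to the original dtype at the end. A toy sketch of that round trip, assuming the converter scales integer images into `[0, 1]`; the two functions below are illustrative, not the torchvision implementation:

```python
import torch

def to_float32(image: torch.Tensor) -> torch.Tensor:
    if image.is_floating_point():
        return image.to(torch.float32)
    return image.to(torch.float32) / torch.iinfo(image.dtype).max

def from_float32(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    if dtype.is_floating_point:
        return image.to(dtype)
    return image.mul(torch.iinfo(dtype).max).round().to(dtype)

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
assert torch.equal(from_float32(to_float32(img), torch.uint8), img)
```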
@@ -289,14 +285,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
 
 
 def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
-    if bits > 8:
-        return image
-
     if image.is_floating_point():
         levels = 1 << bits
         return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
     else:
-        mask = ((1 << bits) - 1) << (8 - bits)
+        num_value_bits = _num_value_bits(image.dtype)
+        if bits >= num_value_bits:
+            return image
+
+        mask = ((1 << bits) - 1) << (num_value_bits - bits)
         return image & mask
 
 
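Posterize keeps only the `bits` most significant value bits. The old mask assumed 8-bit images; with `_num_value_bits` it presumably adapts to the dtype's width minus the sign bit. A rough sketch of the integer branch, where `num_value_bits` is an assumed stand-in for the real helper:

```python
import torch

def num_value_bits(dtype: torch.dtype) -> int:
    # Assumed behaviour: full width for unsigned dtypes, width minus the sign bit otherwise.
    info = torch.iinfo(dtype)
    return info.bits - (1 if info.min < 0 else 0)

def posterize_int(image: torch.Tensor, bits: int) -> torch.Tensor:
    n = num_value_bits(image.dtype)
    if bits >= n:
        return image  # nothing to drop
    # keep the `bits` most significant value bits, zero the rest
    mask = ((1 << bits) - 1) << (n - bits)
    return image & mask

img = torch.tensor([0, 37, 128, 255], dtype=torch.uint8)
print(posterize_int(img, 2))  # tensor([  0,   0, 128, 192], dtype=torch.uint8)
```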
@@ -317,8 +314,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
 
 
 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    bound = 1 if image.is_floating_point() else 255
-    if threshold > bound:
+    if threshold > _FT._max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
 
     return torch.where(image >= threshold, invert_image_tensor(image), image)
@@ -349,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
         # exit earlier on empty images
         return image
 
-    bound = 1.0 if image.is_floating_point() else 255.0
+    bound = _FT._max_value(image.dtype)
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
 
     minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
@@ -383,14 +379,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
 
+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    #    would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
+    #    unfeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    # Since we need to convert in most cases anyway and, out of the acceptable dtypes mentioned in 1., `torch.uint8` is
+    # by far the most common, we choose it as the base.
     output_dtype = image.dtype
-    if image.is_floating_point():
-        # Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
-        # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
-        # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
-        # slower and more complicated to implement than a simple conversion and a fast histogram implementation for
-        # integers.
-        image = convert_dtype_image_tensor(image, torch.uint8)
+    image = convert_dtype_image_tensor(image, torch.uint8)
 
     # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
     # corresponds to adding 1 to index 127 in the histogram.
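For context on the "flattened image as index" remark above, here is a sketch of how one 256-bin histogram per image can be built with `scatter_add_`, which, unlike `torch.histc`, works on batches. The helper name is ad hoc and this is not the code in this diff:

```python
import torch

def batched_hist_uint8(images: torch.Tensor) -> torch.Tensor:
    # images: (..., H, W) uint8; returns one 256-bin histogram per image
    flat = images.reshape(-1, images.shape[-2] * images.shape[-1]).to(torch.int64)
    hist = torch.zeros(flat.shape[0], 256, dtype=torch.int64)
    # each pixel value is used as an index: a value of 127 adds 1 to bin 127
    hist.scatter_add_(1, flat, torch.ones_like(flat))
    return hist

imgs = torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8)
assert batched_hist_uint8(imgs).sum(dim=1).eq(8 * 8).all()
```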
@@ -461,10 +461,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
 
 
 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype == torch.uint8:
+    if image.is_floating_point():
+        return 1.0 - image  # type: ignore[no-any-return]
+    elif image.dtype == torch.uint8:
         return image.bitwise_not()
-    else:
-        return (1 if image.is_floating_point() else 255) - image  # type: ignore[no-any-return]
+    else:  # signed integer dtypes
+        # We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
+        return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
 
 
 invert_image_pil = _FP.invert
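The new signed-integer branch inverts around the dtype maximum without touching the sign bit: xoring with a mask of `_num_value_bits` ones is equivalent to `max_value - image` for non-negative inputs, whereas `bitwise_not` would also flip the sign bit and produce negative values. A small demonstration for `torch.int16`, assuming 15 value bits:

```python
import torch

img = torch.tensor([0, 1, 1000, 32767], dtype=torch.int16)
mask = (1 << 15) - 1  # 0x7FFF: ones in all value bits, zero in the sign bit

print(img.bitwise_xor(mask))  # tensor([32767, 32766, 31767,     0], dtype=torch.int16)
print(img.bitwise_not())      # flips the sign bit too: -1, -2, -1001, -32768
print(torch.equal(img.bitwise_xor(mask), 32767 - img))  # True
```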