Skip to content

Commit

Permalink
Revert "assume that integer images are [0, 255] in equalize (pytorch#…
Browse files Browse the repository at this point in the history
…6859)"

This reverts commit 436ff9a.
  • Loading branch information
pmeier committed Oct 28, 2022
1 parent 6895f71 commit c0236fc
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image

# 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
# would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
# `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
# unfeasible for `torch.int64`.
# 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
# and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
# by far the most common, we choose it as base.
output_dtype = image.dtype
if image.is_floating_point():
# Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
# slower and more complicated to implement than a simple conversion and a fast histogram implementation for
# integers.
image = convert_dtype_image_tensor(image, torch.uint8)
image = convert_dtype_image_tensor(image, torch.uint8)

# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
# corresponds to adding 1 to index 127 in the histogram.
Expand Down

0 comments on commit c0236fc

Please sign in to comment.