another round of perf improvements for equalize (#6776)
* perf improvements for equalize

Co-authored-by: lezcano <lezcano-93@hotmail.com>

* improve reference tests

* add extensive comments and minor fixes to the kernel

* improve comments

Co-authored-by: lezcano <lezcano-93@hotmail.com>
pmeier and lezcano authored Oct 21, 2022
1 parent 9f024a6 commit c041798
Showing 2 changed files with 92 additions and 36 deletions.
40 changes: 38 additions & 2 deletions test/prototype_transforms_kernel_infos.py
@@ -13,6 +13,8 @@
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    get_num_channels,
+    ImageLoader,
     InfoBase,
     make_bounding_box_loaders,
     make_image_loader,
@@ -1359,9 +1361,43 @@ def sample_inputs_equalize_image_tensor():


 def reference_inputs_equalize_image_tensor():
-    for image_loader in make_image_loaders(
-        extra_dims=[()], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
-    ):
-        yield ArgsKwargs(image_loader)
+    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
+    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
+    # the information gain is low if we already provide something really close to the expected value.
+    spatial_size = (256, 256)
+    for fn, color_space in itertools.product(
+        [
+            *[
+                lambda shape, dtype, device, low=low, high=high: torch.randint(
+                    low, high, shape, dtype=dtype, device=device
+                )
+                for low, high in [
+                    (0, 1),
+                    (255, 256),
+                    (0, 64),
+                    (64, 192),
+                    (192, 256),
+                ]
+            ],
+            *[
+                lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
+                .sample(shape)
+                .mul_(255)
+                .round_()
+                .to(dtype=dtype, device=device)
+                for alpha, beta in [
+                    (0.5, 0.5),
+                    (2, 2),
+                    (2, 5),
+                    (5, 2),
+                ]
+            ],
+        ],
+        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
+    ):
+        image_loader = ImageLoader(
+            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
+        )
+        yield ArgsKwargs(image_loader)


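The loaders above deliberately feed the reference tests skewed value distributions. As a quick standalone illustration of why that matters (a sketch, not part of the commit; it assumes only torch), a Beta(0.5, 0.5) sample piles values near 0 and 255, so equalization has real work to do:

import torch

# Sample a 3-channel uint8 image whose values cluster at the extremes.
image = (
    torch.distributions.Beta(0.5, 0.5)
    .sample((3, 256, 256))
    .mul_(255)
    .round_()
    .to(torch.uint8)
)
# The 8-bin histogram comes out U-shaped rather than flat, i.e. far from the
# uniform distribution that the equalize kernel is supposed to produce.
print(torch.histc(image.float(), bins=8, min=0, max=255))
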
88 changes: 54 additions & 34 deletions torchvision/prototype/transforms/functional/_color.py
@@ -228,39 +228,6 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)


-def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor:
-    # input image shape should be [N, H, W]
-    shape = image.shape
-    # Compute image histogram:
-    flat_image = image.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
-    hist = flat_image.new_zeros(shape[0], 256)
-    hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image))
-
-    # Compute image cdf
-    chist = hist.cumsum_(dim=1)
-    # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
-    # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
-    idx = chist.argmax(dim=1).sub_(1)
-    # If histogram is degenerate (hist of zero image), index is -1
-    neg_idx_mask = idx < 0
-    idx.clamp_(min=0)
-    step = chist.gather(dim=1, index=idx.unsqueeze(1))
-    step[neg_idx_mask] = 0
-    step.div_(255, rounding_mode="floor")
-
-    # Compute batched Look-up-table:
-    # Necessary to avoid an integer division by zero, which raises
-    clamped_step = step.clamp(min=1)
-    chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
-    lut = chist.to(torch.uint8)  # [N, 256]
-
-    # Pad lut with zeros
-    zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
-    lut = torch.cat([zeros, lut[:, :-1]], dim=1)
-
-    return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image))


 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.dtype != torch.uint8:
         raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
@@ -272,7 +239,60 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image

-    return _equalize_image_tensor_vec(image.reshape(-1, height, width)).reshape(image.shape)
+    batch_shape = image.shape[:-2]
+    flat_image = image.flatten(start_dim=-2).to(torch.long)
+
+    # The algorithm for histogram equalization is mirrored from PIL:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
+
+    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with
+    # uint8 images here, the values are already binned and the computation is trivial. The histogram is computed
+    # by using the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding
+    # 1 to index 127 in the histogram.
+    hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
+    hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
+    cum_hist = hist.cumsum(dim=-1)
+
+    # The simplest form of lookup table (LUT) that also achieves histogram equalization is
+    # `lut = cum_hist / flat_image.shape[-1] * 255`
+    # However, PIL uses a more elaborate scheme:
+    # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
+
+    # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
+    # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but
+    # rather to the maximum value in the image, which might or might not be 255.
+    index = cum_hist.argmax(dim=-1)
+    num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1))
+
+    # This is a performance optimization that saves us one multiplication later. With it, the LUT computation
+    # simplifies to `lut = (cum_hist + step // 2) // step`, saving the final multiplication by 255 while keeping the
+    # division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison.
+    step = num_non_max_pixels.div_(255, rounding_mode="floor")
+
+    # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not
+    # as easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it
+    # doesn't, we have to go through the computation below anyway. Since `step == 0` is an edge case, it makes no
+    # sense to pay the runtime cost of checking it every time.
+    no_equalization = step.eq(0).unsqueeze_(-1)
+
+    # `lut[k]` is computed from `cum_hist[k - 1]`, with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
+    # computation only for `lut[1:]` with `cum_hist[:-1]` and prepend `lut[0] == 0` afterwards.
+    cum_hist = cum_hist[..., :-1]
+    (
+        cum_hist.add_(step // 2)
+        # We need the `clamp_(min=1)` call here to avoid a zero division, which fails for integer dtypes. This has
+        # no effect on the result of this kernel, since images inside the batch with `step == 0` are returned as is
+        # rather than equalized.
+        .div_(step.clamp_(min=1), rounding_mode="floor")
+        # We need the `clamp_` call here since PIL's LUT computation scheme can produce values outside the valid
+        # value range of uint8 images.
+        .clamp_(0, 255)
+    )
+    lut = cum_hist.to(torch.uint8)
+    lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
+    equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
+
+    return torch.where(no_equalization, image, equalized_image)


 equalize_image_pil = _FP.equalize
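The batched histogram trick described in the kernel's comments is easiest to see on toy numbers. A minimal sketch (not part of the commit) with two 2x2 "images" whose pixel values double as histogram bin indices:

import torch

# Two flattened images of 4 pixels each; values lie in [0, 3], so 4 bins suffice.
flat_image = torch.tensor([[0, 0, 3, 3], [1, 2, 2, 2]])
hist = flat_image.new_zeros(2, 4)
# Each pixel adds 1 to the bin addressed by its own value, per image in the batch.
hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
print(hist)                 # tensor([[2, 0, 0, 2], [0, 1, 3, 0]])
print(hist.cumsum(dim=-1))  # tensor([[2, 2, 2, 4], [0, 1, 4, 4]])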

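To see the PIL-style LUT arithmetic end to end, consider a 16x16 grayscale image with 128 pixels of value 100, 127 pixels of value 200, and a single pixel of value 255. Then `num_non_max_pixels == 255`, `step == 255 // 255 == 1`, and `lut[k] == clamp(cum_hist[k - 1] + step // 2, 0, 255)`, which maps 100 -> 0, 200 -> 128, and 255 -> 255. A sketch of that check (assuming the kernel is importable from the prototype functional namespace as written below; adjust the import to wherever `equalize_image_tensor` lives in your checkout):

import torch
from torchvision.prototype.transforms.functional import equalize_image_tensor

image = torch.full((1, 16, 16), 100, dtype=torch.uint8)
image.view(-1)[128:255] = 200  # 127 pixels
image.view(-1)[255] = 255      # a single max-valued pixel

# The three input levels are spread out to 0, 128, and 255.
print(torch.unique(equalize_image_tensor(image)))
# tensor([  0, 128, 255], dtype=torch.uint8)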
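The `step == 0` edge case discussed in the kernel's comments is also easy to hit directly: a constant image has no non-max pixels, so it must come back unchanged rather than equalized, matching PIL's early-return behavior. Same import assumption as the previous sketch:

import torch
from torchvision.prototype.transforms.functional import equalize_image_tensor

# All 256 pixels share one value, so num_non_max_pixels == 0 and step == 0:
# the image is selected as is by the final torch.where.
constant = torch.full((1, 16, 16), 42, dtype=torch.uint8)
assert torch.equal(equalize_image_tensor(constant), constant)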