diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index 133508f5f94..eb90508fa5c 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -13,6 +13,8 @@ from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    get_num_channels,
+    ImageLoader,
     InfoBase,
     make_bounding_box_loaders,
     make_image_loader,
@@ -1359,9 +1361,43 @@ def sample_inputs_equalize_image_tensor():
 
 
 def reference_inputs_equalize_image_tensor():
-    for image_loader in make_image_loaders(
-        extra_dims=[()], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
+    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
+    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
+    # the information gain is low if we already provide something close to the expected value.
+    spatial_size = (256, 256)
+    for fn, color_space in itertools.product(
+        [
+            *[
+                lambda shape, dtype, device, low=low, high=high: torch.randint(
+                    low, high, shape, dtype=dtype, device=device
+                )
+                for low, high in [
+                    (0, 1),
+                    (255, 256),
+                    (0, 64),
+                    (64, 192),
+                    (192, 256),
+                ]
+            ],
+            *[
+                lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
+                .sample(shape)
+                .mul_(255)
+                .round_()
+                .to(dtype=dtype, device=device)
+                for alpha, beta in [
+                    (0.5, 0.5),
+                    (2, 2),
+                    (2, 5),
+                    (5, 2),
+                ]
+            ],
+        ],
+        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
     ):
+        image_loader = ImageLoader(
+            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
+        )
         yield ArgsKwargs(image_loader)
 
 
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index ae07cc0056d..68b52fff637 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -228,39 +228,6 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)
 
 
-def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor:
-    # input image shape should be [N, H, W]
-    shape = image.shape
-    # Compute image histogram:
-    flat_image = image.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
-    hist = flat_image.new_zeros(shape[0], 256)
-    hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image))
-
-    # Compute image cdf
-    chist = hist.cumsum_(dim=1)
-    # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
-    # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
-    idx = chist.argmax(dim=1).sub_(1)
-    # If histogram is degenerate (hist of zero image), index is -1
-    neg_idx_mask = idx < 0
-    idx.clamp_(min=0)
-    step = chist.gather(dim=1, index=idx.unsqueeze(1))
-    step[neg_idx_mask] = 0
-    step.div_(255, rounding_mode="floor")
-
-    # Compute batched Look-up-table:
-    # Necessary to avoid an integer division by zero, which raises
-    clamped_step = step.clamp(min=1)
-    chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
-    lut = chist.to(torch.uint8)  # [N, 256]
-
-    # Pad lut with zeros
-    zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
-    lut = torch.cat([zeros, lut[:, :-1]], dim=1)
-
-    return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image))
-
-
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.dtype != torch.uint8:
         raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
@@ -272,7 +239,60 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
 
-    return _equalize_image_tensor_vec(image.reshape(-1, height, width)).reshape(image.shape)
+    batch_shape = image.shape[:-2]
+    flat_image = image.flatten(start_dim=-2).to(torch.long)
+
+    # The algorithm for histogram equalization is mirrored from PIL:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
+
+    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
+    # images here and the values are thus already binned, the computation is trivial: the histogram is computed by
+    # using the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to
+    # index 127 in the histogram.
+    hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
+    hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
+    cum_hist = hist.cumsum(dim=-1)
+
+    # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
+    # `lut = cum_hist / flat_image.shape[-1] * 255`
+    # However, PIL uses a more elaborate scheme:
+    # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
+
+    # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
+    # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but
+    # rather to the maximum value in the image, which might or might not be 255.
+    index = cum_hist.argmax(dim=-1)
+    num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1))
+
+    # This is a performance optimization that saves us one multiplication later. With this, the LUT computation
+    # simplifies to `lut = (cum_hist + step // 2) // step`, saving the final multiplication by 255 while keeping the
+    # division count the same. PIL uses the variable name `step` for this, so we keep it for easier comparison.
+    step = num_non_max_pixels.div_(255, rounding_mode="floor")
+
+    # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as
+    # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds, and if it
+    # doesn't, we have to go through the computation below anyway. Since `step == 0` is an edge case, it makes no
+    # sense to pay the runtime cost of checking it every time.
+    no_equalization = step.eq(0).unsqueeze_(-1)
+
+    # `lut[k]` is computed from `cum_hist[k-1]`, with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
+    # computation only for `lut[1:]` from `cum_hist[:-1]` and prepend `lut[0] == 0` afterwards.
+    cum_hist = cum_hist[..., :-1]
+    (
+        cum_hist.add_(step // 2)
+        # We need the `clamp_(min=1)` call here to avoid a zero division, which raises an error for integer dtypes.
+        # This has no effect on the returned result of this kernel, since images in the batch with `step == 0` are
+        # returned as-is instead of their equalized version.
+        .div_(step.clamp_(min=1), rounding_mode="floor")
+        # We need the `clamp_` call here since PIL's LUT computation scheme can produce values outside the valid
+        # value range of uint8 images.
+        .clamp_(0, 255)
+    )
+    lut = cum_hist.to(torch.uint8)
+    lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
+    equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
+
+    return torch.where(no_equalization, image, equalized_image)
 
 
 equalize_image_pil = _FP.equalize
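Note: below is a minimal, non-batched sketch of the PIL-style LUT scheme that the comments in `equalize_image_tensor` describe, useful for sanity-checking the batched kernel above against a single channel. The helper name `equalize_single_channel` is hypothetical and not part of torchvision.

    import torch


    def equalize_single_channel(image: torch.Tensor) -> torch.Tensor:
        # `image` is a single uint8 channel of shape [H, W].
        flat = image.flatten().to(torch.long)
        hist = torch.bincount(flat, minlength=256)
        cum_hist = hist.cumsum(dim=0)

        # Number of pixels that do not have the maximum value occurring in the image,
        # floor-divided by 255. This mirrors PIL's `step` variable.
        num_non_max_pixels = flat.numel() - hist[cum_hist.argmax()]
        step = int(num_non_max_pixels) // 255
        if step == 0:
            # Degenerate histogram, e.g. a constant image: nothing to equalize.
            return image

        # `lut[0] == 0` and `lut[k] == (cum_hist[k - 1] + step // 2) // step` for `k >= 1`.
        lut = torch.cat([cum_hist.new_zeros(1), (cum_hist[:-1] + step // 2) // step])
        return lut.clamp_(0, 255).to(torch.uint8)[flat].reshape_as(image)


    image = torch.randint(64, 192, (16, 16), dtype=torch.uint8)
    print(equalize_single_channel(image))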