From 938c8372e1a7696a59b398454c48ef7a28a1017c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 12 Oct 2022 13:46:20 +0000 Subject: [PATCH 1/2] [proto] Performance improvements for equalize op --- .../prototype/transforms/functional/_color.py | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 63fa8a28cfe..49c044ff78c 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -207,7 +207,7 @@ def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor: return lut[img_chan.to(torch.int64)] -def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: +def equalize_image_tensor_slow(image: torch.Tensor) -> torch.Tensor: if image.dtype != torch.uint8: raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") @@ -223,6 +223,55 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape) +def _equalize_image_tensor_vec(img): + # input img shape should be [N, H, W] + shape = img.shape + # Compute image histogram: + flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W] + hist = flat_img.new_zeros(shape[0], 256) + hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img)) + + # Compute image cdf + chist = hist.cumsum_(dim=1) + # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255 + # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax() + idx = chist.argmax(dim=1).sub_(1) + # If histogram is degenerate (hist of zero image), index is -1 + neg_idx_mask = idx < 0 + idx.clamp_(min=0) + step = chist.gather(dim=1, index=idx.unsqueeze(1)) + step[neg_idx_mask] = 0 + step.div_(255, rounding_mode="floor") + + # Compute batched Look-up-table: + # Necessary to avoid an integer division by zero, which raises + clamped_step = step.clamp(min=1) + chist.add_(torch.div(step, 2, rounding_mode="floor")) \ + .div_(clamped_step, rounding_mode="floor") \ + .clamp_(0, 255) + lut = chist.to(torch.uint8) # [N, 256] + + # Pad lut with zeros + zeros = lut.new_zeros((1, 1)).expand(shape[0], 1) + lut = torch.cat([zeros, lut[:, :-1]], dim=1) + + return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img)) + + +def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: + if image.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") + + num_channels, height, width = get_dimensions_image_tensor(image) + if num_channels not in (1, 3): + raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + + if image.numel() == 0: + return image + + return _equalize_image_tensor_vec(image.view(-1, height, width)).view(image.shape) + + equalize_image_pil = _FP.equalize From 2daa1acced28de25e09df21a76738fff16f55322 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 12 Oct 2022 14:23:13 +0000 Subject: [PATCH 2/2] Added tests --- test/test_prototype_transforms_functional.py | 11 +++++ .../prototype/transforms/functional/_color.py | 48 ++----------------- 2 files changed, 14 insertions(+), 45 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 982d776bdd0..34291611d8d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1037,3 +1037,14 @@ def test_to_image_pil(inpt, mode): assert isinstance(output, PIL.Image.Image) assert np.asarray(inpt).sum() == np.asarray(output).sum() + + +def test_equalize_image_tensor_edge_cases(): + inpt = torch.zeros(3, 200, 200, dtype=torch.uint8) + output = F.equalize_image_tensor(inpt) + torch.testing.assert_close(inpt, output) + + inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8) + inpt[..., 100:, 100:] = 1 + output = F.equalize_image_tensor(inpt) + assert output.unique().tolist() == [0, 255] diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 49c044ff78c..7cbf8885ca9 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -183,51 +183,11 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return autocontrast_image_pil(inpt) -def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor: - # TODO: we should expect bincount to always be faster than histc, but this - # isn't always the case. Once - # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if - # block and only use bincount. - if img_chan.is_cuda: - hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) - else: - hist = torch.bincount(img_chan.view(-1), minlength=256) - - nonzero_hist = hist[hist != 0] - step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor") - if step == 0: - return img_chan - - lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor") - # Doing inplace clamp and converting lut to uint8 improves perfs - lut.clamp_(0, 255) - lut = lut.to(torch.uint8) - lut = torch.nn.functional.pad(lut[:-1], [1, 0]) - - return lut[img_chan.to(torch.int64)] - - -def equalize_image_tensor_slow(image: torch.Tensor) -> torch.Tensor: - if image.dtype != torch.uint8: - raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") - - num_channels, height, width = get_dimensions_image_tensor(image) - if num_channels not in (1, 3): - raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") - - if image.numel() == 0: - return image - elif image.ndim == 2: - return _scale_channel(image) - else: - return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape) - - -def _equalize_image_tensor_vec(img): +def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor: # input img shape should be [N, H, W] shape = img.shape # Compute image histogram: - flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W] + flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W] hist = flat_img.new_zeros(shape[0], 256) hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img)) @@ -246,9 +206,7 @@ def _equalize_image_tensor_vec(img): # Compute batched Look-up-table: # Necessary to avoid an integer division by zero, which raises clamped_step = step.clamp(min=1) - chist.add_(torch.div(step, 2, rounding_mode="floor")) \ - .div_(clamped_step, rounding_mode="floor") \ - .clamp_(0, 255) + chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255) lut = chist.to(torch.uint8) # [N, 256] # Pad lut with zeros