[proto] Performance improvements for equalize op #6757

Merged · 3 commits · Oct 13, 2022
11 changes: 11 additions & 0 deletions test/test_prototype_transforms_functional.py
@@ -1037,3 +1037,14 @@ def test_to_image_pil(inpt, mode):
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image_tensor(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image_tensor(inpt)
    assert output.unique().tolist() == [0, 255]
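
Working the LUT math by hand for the second case above shows why only 0 and 255 survive: each 200x200 channel has 30000 zeros and 10000 ones, the step is 30000 // 255 = 117, both cumulative entries saturate at 255, and the table is then shifted right by one with a leading zero. A standalone sketch of that arithmetic (it mirrors the new _equalize_image_tensor_vec in the diff below rather than importing it):

import torch

# Histogram of one 200x200 channel from the test image: 30000 zeros, 10000 ones.
hist = torch.zeros(256, dtype=torch.long)
hist[0], hist[1] = 30_000, 10_000

chist = hist.cumsum(0)                           # [30000, 40000, 40000, ...]
idx = chist.argmax()                             # 1, the last nonzero bin
step = chist[idx - 1] // 255                     # 30000 // 255 = 117
lut = ((chist + step // 2) // step).clamp(0, 255)
lut = torch.cat([torch.zeros(1, dtype=lut.dtype), lut[:-1]])  # shift right, pad with 0

print(lut[0].item(), lut[1].item())              # 0 255 -> output values {0, 255}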
59 changes: 33 additions & 26 deletions torchvision/prototype/transforms/functional/_color.py
@@ -183,28 +183,37 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
        return autocontrast_image_pil(inpt)


def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
    # TODO: we should expect bincount to always be faster than histc, but this
    # isn't always the case. Once
    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
    # block and only use bincount.
    if img_chan.is_cuda:
        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
    else:
        hist = torch.bincount(img_chan.view(-1), minlength=256)

    nonzero_hist = hist[hist != 0]
    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
    if step == 0:
        return img_chan

    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
    # Doing inplace clamp and converting lut to uint8 improves perfs
    lut.clamp_(0, 255)
    lut = lut.to(torch.uint8)
    lut = torch.nn.functional.pad(lut[:-1], [1, 0])

    return lut[img_chan.to(torch.int64)]
def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
    # input img shape should be [N, H, W]
    shape = img.shape
    # Compute image histogram:
    flat_img = img.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
    hist = flat_img.new_zeros(shape[0], 256)
    hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))

    # Compute image cdf
    chist = hist.cumsum_(dim=1)
    # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
    # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
    idx = chist.argmax(dim=1).sub_(1)
    # If histogram is degenerate (hist of zero image), index is -1
    neg_idx_mask = idx < 0
    idx.clamp_(min=0)
    step = chist.gather(dim=1, index=idx.unsqueeze(1))
    step[neg_idx_mask] = 0
    step.div_(255, rounding_mode="floor")

    # Compute batched Look-up-table:
    # Necessary to avoid an integer division by zero, which raises
    clamped_step = step.clamp(min=1)
    chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
    lut = chist.to(torch.uint8)  # [N, 256]

    # Pad lut with zeros
    zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
    lut = torch.cat([zeros, lut[:, :-1]], dim=1)

    return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))

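
A note on the chist.argmax() trick in the comments above: the cumulative histogram first reaches its maximum at the last nonzero bin, so the entry just before that index equals the count of all pixels except those in that last bin, i.e. nonzero_hist[:-1].sum(). A toy check of the identity (standalone sketch, relying on torch.argmax returning the first maximal index, as documented):

import torch

hist = torch.tensor([5, 0, 3, 0, 7, 0], dtype=torch.long)  # toy histogram
chist = hist.cumsum(0)                                      # [5, 5, 8, 8, 15, 15]

nonzero_hist = hist[hist != 0]                              # [5, 3, 7]
idx = chist.argmax()                                        # 4: first index of the max, i.e. the last nonzero bin
assert chist[idx - 1] == nonzero_hist[:-1].sum()            # 8 == 8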

def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -217,10 +226,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:

    if image.numel() == 0:
        return image
    elif image.ndim == 2:
        return _scale_channel(image)
    else:
        return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)

    return _equalize_image_tensor_vec(image.view(-1, height, width)).view(image.shape)


equalize_image_pil = _FP.equalize
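
For anyone wanting to reproduce the speed-up locally, a rough micro-benchmark sketch with torch.utils.benchmark is below; the import path assumes the prototype functional namespace, and the numbers will vary by machine and input size.

import torch
import torch.utils.benchmark as benchmark
from torchvision.prototype.transforms import functional as F

imgs = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)  # batch of uint8 images
timer = benchmark.Timer(
    stmt="F.equalize_image_tensor(imgs)",
    globals={"F": F, "imgs": imgs},
    description="batched uint8 equalize",
)
print(timer.blocked_autorange(min_run_time=1))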