Histogram Equalization #1049
Comments
@fmassa, I would like to take this up if we want to add this.
Looks like this histogram matching feature has already been requested before in #598 and was implemented in #796 but hasn't been merged.

import torch


def torch_equalize(image):
    """Implements Equalize function from PIL using PyTorch ops based on:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352"""

    def scale_channel(im, c):
        """Scale the data in the channel to implement equalize."""
        im = im[:, :, c]
        # Compute the histogram of the image channel.
        # Cast to float first, since torch.histc may not accept integer tensors.
        histo = torch.histc(im.float(), bins=256, min=0, max=255)
        # For the purposes of computing the step, filter out the nonzeros.
        nonzero_histo = torch.reshape(histo[histo != 0], [-1])
        step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255

        def build_lut(histo, step):
            # Compute the cumulative sum, shifting by step // 2
            # and then normalizing by step.
            lut = (torch.cumsum(histo, 0) + (step // 2)) // step
            # Shift lut, prepending with 0.
            lut = torch.cat([torch.zeros(1), lut[:-1]])
            # Clip the counts to be in range. This is done
            # in the C code for image.point.
            return torch.clamp(lut, 0, 255)

        # If step is zero, return the original image. Otherwise, build
        # lut from the full histogram and step and then index from it.
        if step == 0:
            result = im
        else:
            # Can't index using a 2d index. Have to flatten and then reshape.
            result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
            result = result.reshape_as(im)
        return result.type(torch.uint8)

    # Assumes RGB for now (image laid out as H x W x C). Scales each
    # channel independently and then stacks the result.
    s1 = scale_channel(image, 0)
    s2 = scale_channel(image, 1)
    s3 = scale_channel(image, 2)
    image = torch.stack([s1, s2, s3], 2)
    return image
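
For anyone trying to follow the lookup-table arithmetic in the snippet above, here is a rough pure-Python sketch of the same steps for a single channel, without any tensor ops. The names `equalize_lut` and `equalize_pixels` are made up for illustration, and the early return for a single-valued channel is an added guard, not something the snippet above does explicitly.

```python
def equalize_lut(histogram):
    """Build a 256-entry lookup table from a 256-bin histogram (sketch)."""
    nonzero = [count for count in histogram if count != 0]
    # Added guard (assumption): with at most one populated bin there is
    # nothing to spread out, so keep the identity mapping.
    if len(nonzero) <= 1:
        return list(range(256))
    # Step: total pixel count minus the last nonzero bin, divided by 255.
    step = (sum(nonzero) - nonzero[-1]) // 255
    if step == 0:
        # Too few pixels to rescale; identity mapping, as in the snippet.
        return list(range(256))
    lut = []
    running = step // 2  # start at half a step so the division rounds
    for count in histogram:
        # Entry i depends on the cumulative count of the bins below i,
        # which is the "shift lut, prepending with 0" step.
        lut.append(min(running // step, 255))
        running += count
    return lut


def equalize_pixels(pixels):
    """Equalize a flat list of 8-bit values from one channel (sketch)."""
    histogram = [0] * 256
    for value in pixels:
        histogram[value] += 1
    lut = equalize_lut(histogram)
    return [lut[value] for value in pixels]


# A low-contrast channel: three gray levels, 510 pixels each.
low_contrast = [100] * 510 + [110] * 510 + [120] * 510
stretched = equalize_pixels(low_contrast)
```

With this input the step works out to 4, so the three gray levels 100, 110 and 120 are mapped to 0, 128 and 255, which is the contrast stretch equalization is meant to produce.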
I have been using cv2 to histogram equalize my images. I recently found that PIL has a function called ImageOps.equalize(image, mask=None) which does the same thing. This transform has been really useful to me. Since this transform is implemented in PIL, which is a supported backend, I was wondering if it would be a good addition to torchvision.transforms.
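
To make the request a bit more concrete, here is a minimal sketch of what such a transform could look like if it simply wrapped the PIL call. The class name `Equalize` is hypothetical and is not taken from torchvision; only `ImageOps.equalize` comes from PIL, and the sketch assumes Pillow is installed.

```python
from PIL import Image, ImageOps


class Equalize:
    """Hypothetical transform: histogram-equalize a PIL image."""

    def __call__(self, img):
        # Delegate to PIL; the optional mask argument is left at its default.
        return ImageOps.equalize(img)

    def __repr__(self):
        return self.__class__.__name__ + "()"


# Usage on a small low-contrast grayscale image built in memory.
raw = bytes([100] * 510 + [110] * 510 + [120] * 510)
img = Image.frombytes("L", (1530, 1), raw)
out = Equalize()(img)
```

Because it is a plain callable on PIL images, something like this could sit in a transform pipeline before any conversion to a tensor.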