
Histogram Equalization #1049

Closed
dakshjotwani opened this issue Jun 25, 2019 · 3 comments

Comments

@dakshjotwani
Contributor

I have been using cv2 to histogram-equalize my images. I recently found that PIL has a function, ImageOps.equalize(image, mask=None), that does the same thing. This transform has been really useful to me. Since it is implemented in PIL, which is a supported backend, I was wondering if it would be a good addition to torchvision.transforms.
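For context, here is a minimal sketch of what such a transform could look like when it simply delegates to PIL. The `Equalize` class name and the synthetic low-contrast test image are assumptions made for illustration; only `ImageOps.equalize(image, mask=None)` comes from the request above.

```python
from PIL import Image, ImageOps


class Equalize:
    """Hypothetical transform-style wrapper; not an existing torchvision class."""

    def __init__(self, mask=None):
        self.mask = mask

    def __call__(self, img):
        # ImageOps.equalize remaps pixel values so that the output
        # histogram is roughly uniform.
        return ImageOps.equalize(img, mask=self.mask)


# Synthetic low-contrast grayscale image: values confined to 100..131.
width, height = 32, 32
pixels = bytes(100 + (i % 32) for i in range(width * height))
img = Image.frombytes("L", (width, height), pixels)

out = Equalize()(img)
print("input range:", img.getextrema())
print("output range:", out.getextrema())
```

Because the wrapper is a plain callable on PIL images, it could in principle be placed in a `transforms.Compose` pipeline ahead of `ToTensor`, although that usage is not shown in this thread.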

@surgan12
Contributor

surgan12 commented Jun 26, 2019

@fmassa, I would like to take this up if we decide to add it.

@zakajd

zakajd commented Aug 7, 2019

It looks like this histogram matching feature was already requested in #598 and implemented in #796, but it hasn't been merged.
I've implemented a much simpler version, identical to the one in PIL.ImageOps.equalize.

import torch


def torch_equalize(image):
    """Implements Equalize function from PIL using PyTorch ops based on:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352

    Expects an (H, W, 3) tensor with values in [0, 255]; returns a uint8 tensor."""
    def scale_channel(im, c):
        """Scale the data in the channel to implement equalize."""
        im = im[:, :, c]
        # Compute the histogram of the image channel. torch.histc needs a
        # floating-point input; the counts are cast to integers afterwards.
        histo = torch.histc(im.float(), bins=256, min=0, max=255).long()
        # For the purposes of computing the step, keep only the nonzero bins.
        nonzero_histo = histo[histo != 0]
        step = torch.div(nonzero_histo.sum() - nonzero_histo[-1], 255,
                         rounding_mode="floor")

        def build_lut(histo, step):
            # Compute the cumulative sum, shifting by step // 2
            # and then normalizing by step.
            half_step = torch.div(step, 2, rounding_mode="floor")
            lut = torch.div(torch.cumsum(histo, 0) + half_step, step,
                            rounding_mode="floor")
            # Shift lut, prepending with 0.
            lut = torch.cat([lut.new_zeros(1), lut[:-1]])
            # Clip the counts to be in range.  This is done
            # in the C code for image.point.
            return torch.clamp(lut, 0, 255)

        # If step is zero, return the original image.  Otherwise, build
        # lut from the full histogram and step and then index from it.
        if step == 0:
            result = im
        else:
            # Look up the flattened channel in the 1-D lut, then restore the shape.
            result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
            result = result.reshape_as(im)

        return result.type(torch.uint8)

    # Assumes RGB for now.  Scales each channel independently
    # and then stacks the result.
    s1 = scale_channel(image, 0)
    s2 = scale_channel(image, 1)
    s3 = scale_channel(image, 2)
    image = torch.stack([s1, s2, s3], 2)
    return image
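
To make the step and lookup-table arithmetic above concrete, here is a small pure-Python sketch of the same computation on a single channel. The helper name `equalize_channel` and the toy input are assumptions for illustration only; the sketch follows the histogram, step, and shifted cumulative-sum logic of the snippet above rather than any library implementation.

```python
def equalize_channel(values):
    """Equalize a flat list of integer pixel values in [0, 255].

    Hypothetical helper for illustration; plain-Python version of the
    histogram, step, and lookup-table arithmetic.
    """
    histo = [0] * 256
    for v in values:
        histo[v] += 1

    nonzero = [count for count in histo if count]
    # Empty or constant input: nothing to equalize.
    if len(nonzero) <= 1:
        return list(values)
    step = (sum(nonzero) - nonzero[-1]) // 255
    # Too few pixels to spread over 255 levels: leave the input unchanged.
    if step == 0:
        return list(values)

    lut = []
    running = step // 2  # rounding offset
    for count in histo:
        # Each level maps to the scaled count of pixels strictly below it,
        # clipped to the valid 8-bit range.
        lut.append(min(running // step, 255))
        running += count
    return [lut[v] for v in values]


# Four gray levels, 255 pixels each, packed into a narrow range.
values = [10] * 255 + [20] * 255 + [30] * 255 + [40] * 255
print(sorted(set(equalize_channel(values))))  # [0, 85, 170, 255]
```

Here the total is 1020 pixels, the last nonzero bin holds 255, so step = 765 // 255 = 3, and the four input levels are spread evenly across the full output range.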

@datumbox
Contributor

datumbox commented Feb 5, 2021

Closing, as this has been implemented in #3119 and #3123.

@datumbox datumbox closed this as completed Feb 5, 2021

4 participants