From 3dd2e3d84e920640029dd3718eed4a7c8a1ab38c Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 24 Oct 2022 11:16:11 +0200 Subject: [PATCH] [proto] Speed improvement for autocontrast op (#6811) * WIP * Updates to speed up autocontrast Co-authored-by: Vasilis Vryniotis --- .../prototype/transforms/functional/_color.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 68b52fff637..7bf412aaf99 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -211,7 +211,34 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp return solarize_image_pil(inpt, threshold=threshold) -autocontrast_image_tensor = _FT.autocontrast +def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: + + if not (isinstance(image, torch.Tensor)): + raise TypeError("Input img should be Tensor image") + + c = get_num_channels_image_tensor(image) + + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + + if image.numel() == 0: + # exit earlier on empty images + return image + + bound = 1.0 if image.is_floating_point() else 255.0 + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + + minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype) + maximum = image.amax(dim=(-2, -1), keepdim=True).to(dtype) + + scale = bound / (maximum - minimum) + eq_idxs = maximum == minimum + minimum[eq_idxs] = 0.0 + scale[eq_idxs] = 1.0 + + return (image - minimum).mul_(scale).clamp_(0, bound).to(image.dtype) + + autocontrast_image_pil = _FP.autocontrast