From 7de5d1fef71e4182d6d11b04b546a43a71bc2847 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 21 Oct 2022 13:56:36 +0000 Subject: [PATCH 1/2] WIP --- .../prototype/transforms/functional/_color.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 68b52fff637..5754721100b 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -211,7 +211,43 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp return solarize_image_pil(inpt, threshold=threshold) -autocontrast_image_tensor = _FT.autocontrast +def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: + + if not (isinstance(image, torch.Tensor)): + raise TypeError("Input img should be Tensor image") + + c = get_num_channels_image_tensor(image) + + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + + # if img.ndim < 3: + # raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + + # _assert_channels(img, [1, 3]) + + bound = 1.0 if image.is_floating_point() else 255.0 + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + + # let's squash spatial dims into a single dim, such that torch.aminmax support it + shape = image.shape[:-2] + (image.shape[-2] * image.shape[-1], ) + print("shape:", shape, image.numel()) + + minimum = torch.amin(image.view(shape), dim=(-2, -1), keepdim=True) + maximum = torch.amax(image.view(shape), dim=(-2, -1), keepdim=True) + print(minimum, maximum) + + minimum, maximum = torch.aminmax(image.view(shape), dim=-1, keepdim=True) + minimum = minimum.to(dtype).unsqueeze(-1) + maximum = maximum.to(dtype).unsqueeze(-1) + scale = bound / (maximum - minimum) + eq_idxs = maximum == minimum + minimum[eq_idxs] = 0.0 + scale[eq_idxs] = 1.0 + + return ((image - minimum) * scale).clamp_(0, bound).to(image.dtype) + + autocontrast_image_pil = _FP.autocontrast From 03c987d92024147b4fa486e5916506dd663eb92b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 21 Oct 2022 14:26:53 +0000 Subject: [PATCH 2/2] Updates to speed up autocontrast --- .../prototype/transforms/functional/_color.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 5754721100b..7bf412aaf99 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -221,31 +221,22 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") - # if img.ndim < 3: - # raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") - - # _assert_channels(img, [1, 3]) + if image.numel() == 0: + # exit earlier on empty images + return image bound = 1.0 if image.is_floating_point() else 255.0 dtype = image.dtype if torch.is_floating_point(image) else torch.float32 - # let's squash spatial dims into a single dim, such that torch.aminmax support it - shape = image.shape[:-2] + (image.shape[-2] * image.shape[-1], ) - print("shape:", shape, image.numel()) - - minimum = torch.amin(image.view(shape), dim=(-2, -1), keepdim=True) - maximum = torch.amax(image.view(shape), dim=(-2, -1), keepdim=True) - print(minimum, maximum) + minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype) + maximum = image.amax(dim=(-2, -1), keepdim=True).to(dtype) - minimum, maximum = torch.aminmax(image.view(shape), dim=-1, keepdim=True) - minimum = minimum.to(dtype).unsqueeze(-1) - maximum = maximum.to(dtype).unsqueeze(-1) scale = bound / (maximum - minimum) eq_idxs = maximum == minimum minimum[eq_idxs] = 0.0 scale[eq_idxs] = 1.0 - return ((image - minimum) * scale).clamp_(0, bound).to(image.dtype) + return (image - minimum).mul_(scale).clamp_(0, bound).to(image.dtype) autocontrast_image_pil = _FP.autocontrast