From 94e918c68572b02c8165d65716cdd48a720c8b60 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 17 Oct 2022 20:58:22 +0000 Subject: [PATCH 01/11] WIP --- .../prototype/transforms/functional/_color.py | 61 +++++++++++++++++-- torchvision/transforms/functional_tensor.py | 7 +-- 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 49a769e04e0..43f03ebab87 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,9 +2,32 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._meta import get_dimensions_image_tensor +from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor + + +def _blend(img1: torch.Tensor, img2: torch.Tensor, ratio: float) -> torch.Tensor: + dtype = img1.dtype + fp = dtype.is_floating_point + bound = 1.0 if fp else 255.0 + if not fp and img2.is_cuda: + img2 = img2 * (1.0 - ratio) + else: + if not fp: + img2 = img2.to(torch.float32) + img2.mul_(1.0 - ratio) + img2.add_(img1, alpha=ratio).clamp_(0, bound) + return img2 if fp else img2.to(dtype) + + +def adjust_brightness_image_tensor(img: torch.Tensor, brightness_factor: float) -> torch.Tensor: + if brightness_factor < 0: + raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") + + _FT._assert_channels(img, [1, 3]) + + return _blend(img, torch.zeros_like(img, dtype=torch.float32), brightness_factor) + -adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness @@ -21,7 +44,20 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) -adjust_saturation_image_tensor = _FT.adjust_saturation +def adjust_saturation_image_tensor(img: torch.Tensor, saturation_factor: float) -> torch.Tensor: + if saturation_factor < 0: + raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") + + c = get_num_channels_image_tensor(img) + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + + if c[0] == 1: # Match PIL behaviour + return img + + return _blend(img, _FT.rgb_to_grayscale(img), saturation_factor) + + adjust_saturation_image_pil = _FP.adjust_saturation @@ -38,7 +74,22 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) -adjust_contrast_image_tensor = _FT.adjust_contrast +def adjust_contrast_image_tensor(img: torch.Tensor, contrast_factor: float) -> torch.Tensor: + if contrast_factor < 0: + raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") + + c = get_num_channels_image_tensor(img) + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + if c == 3: + mean = torch.mean(_FT.rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) + else: + mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) + + return _blend(img, mean, contrast_factor) + + adjust_contrast_image_pil = _FP.adjust_contrast @@ -74,7 +125,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) else: needs_unsquash = False - output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor) + output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor) if needs_unsquash: output = output.reshape(shape) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 4944c75fab8..ca641faf161 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -816,12 +816,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: kernel /= kernel.sum() kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( - img, - [ - kernel.dtype, - ], - ) + result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) From fc4f237d34830b1336f89e9d253ebae6d80bc44b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 18 Oct 2022 09:03:03 +0000 Subject: [PATCH 02/11] _blend optim v1 --- torchvision/prototype/transforms/functional/_color.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 43f03ebab87..de7ded3fe52 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -7,7 +7,7 @@ def _blend(img1: torch.Tensor, img2: torch.Tensor, ratio: float) -> torch.Tensor: dtype = img1.dtype - fp = dtype.is_floating_point + fp = img1.is_floating_point() bound = 1.0 if fp else 255.0 if not fp and img2.is_cuda: img2 = img2 * (1.0 - ratio) @@ -52,10 +52,10 @@ def adjust_saturation_image_tensor(img: torch.Tensor, saturation_factor: float) if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") - if c[0] == 1: # Match PIL behaviour + if c == 1: # Match PIL behaviour return img - return _blend(img, _FT.rgb_to_grayscale(img), saturation_factor) + return _blend(img, _FT.rgb_to_grayscale(img).expand_as(img).clone(), saturation_factor) adjust_saturation_image_pil = _FP.adjust_saturation @@ -87,7 +87,7 @@ def adjust_contrast_image_tensor(img: torch.Tensor, contrast_factor: float) -> t else: mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) - return _blend(img, mean, contrast_factor) + return _blend(img, mean.expand_as(img).clone(), contrast_factor) adjust_contrast_image_pil = _FP.adjust_contrast From 58eec29a7ea80c91244955c727ae3b4dfb421d5b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 18 Oct 2022 09:46:38 +0000 Subject: [PATCH 03/11] _blend and color ops optims: v2 --- .../prototype/transforms/functional/_color.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index de7ded3fe52..a5423b7f7a6 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -6,17 +6,11 @@ def _blend(img1: torch.Tensor, img2: torch.Tensor, ratio: float) -> torch.Tensor: - dtype = img1.dtype + ratio = float(ratio) fp = img1.is_floating_point() bound = 1.0 if fp else 255.0 - if not fp and img2.is_cuda: - img2 = img2 * (1.0 - ratio) - else: - if not fp: - img2 = img2.to(torch.float32) - img2.mul_(1.0 - ratio) - img2.add_(img1, alpha=ratio).clamp_(0, bound) - return img2 if fp else img2.to(dtype) + output = img1.mul(ratio).add_(img2, alpha=(1.0 - ratio)).clamp_(0, bound) + return output if fp else output.to(img1.dtype) def adjust_brightness_image_tensor(img: torch.Tensor, brightness_factor: float) -> torch.Tensor: @@ -25,7 +19,10 @@ def adjust_brightness_image_tensor(img: torch.Tensor, brightness_factor: float) _FT._assert_channels(img, [1, 3]) - return _blend(img, torch.zeros_like(img, dtype=torch.float32), brightness_factor) + fp = img.is_floating_point() + bound = 1.0 if fp else 255.0 + output = img.mul(brightness_factor).clamp_(0, bound) + return output if fp else output.to(img.dtype) adjust_brightness_image_pil = _FP.adjust_brightness @@ -55,7 +52,7 @@ def adjust_saturation_image_tensor(img: torch.Tensor, saturation_factor: float) if c == 1: # Match PIL behaviour return img - return _blend(img, _FT.rgb_to_grayscale(img).expand_as(img).clone(), saturation_factor) + return _blend(img, _FT.rgb_to_grayscale(img), saturation_factor) adjust_saturation_image_pil = _FP.adjust_saturation @@ -87,7 +84,7 @@ def adjust_contrast_image_tensor(img: torch.Tensor, contrast_factor: float) -> t else: mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) - return _blend(img, mean.expand_as(img).clone(), contrast_factor) + return _blend(img, mean, contrast_factor) adjust_contrast_image_pil = _FP.adjust_contrast From b7b5178c5bf9406536bf9cf59f186a96ec8dfd91 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 18 Oct 2022 12:36:40 +0000 Subject: [PATCH 04/11] updated a/r tol and configs to make tests pass --- test/test_prototype_transforms_consistency.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 212755068d9..4b18b711c08 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -253,9 +253,10 @@ def __init__( legacy_transforms.RandomAdjustSharpness, [ ArgsKwargs(p=0, sharpness_factor=0.5), - ArgsKwargs(p=1, sharpness_factor=0.3), + ArgsKwargs(p=1, sharpness_factor=0.2), ArgsKwargs(p=1, sharpness_factor=0.99), ], + closeness_kwargs={"atol": 1e-6, "rtol": 1e-6}, ), ConsistencyConfig( prototype_transforms.RandomGrayscale, @@ -305,8 +306,9 @@ def __init__( ArgsKwargs(saturation=(0.8, 0.9)), ArgsKwargs(hue=0.3), ArgsKwargs(hue=(-0.1, 0.2)), - ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3), + ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6), ], + closeness_kwargs={"atol": 1e-5, "rtol": 1e-5}, ), *[ ConsistencyConfig( From 2a5e4d88ce72c4226be717c96478ee3dd4167087 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 19 Oct 2022 12:33:18 +0000 Subject: [PATCH 05/11] Loose a/r tolerance in AA tests --- test/test_prototype_transforms_consistency.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 58cabde1d1a..7d72463260e 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -755,7 +755,7 @@ def test_randaug(self, inpt, interpolation, mocker): expected_output = t_ref(inpt) output = t(inpt) - assert_equal(expected_output, output) + assert_close(expected_output, output, atol=1, rtol=0.1) @pytest.mark.parametrize( "inpt", @@ -803,7 +803,7 @@ def test_trivial_aug(self, inpt, interpolation, mocker): expected_output = t_ref(inpt) output = t(inpt) - assert_equal(expected_output, output) + assert_close(expected_output, output, atol=1, rtol=0.1) @pytest.mark.parametrize( "inpt", From a170513047d25ae2fc9a29837464a28247c7c549 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 19 Oct 2022 14:00:27 +0000 Subject: [PATCH 06/11] Use custom rgb_to_grayscale --- .../prototype/transforms/functional/_color.py | 6 ++--- .../prototype/transforms/functional/_meta.py | 26 ++++++++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index a5423b7f7a6..242347dab11 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,7 +2,7 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor +from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor, rgb_to_grayscale def _blend(img1: torch.Tensor, img2: torch.Tensor, ratio: float) -> torch.Tensor: @@ -52,7 +52,7 @@ def adjust_saturation_image_tensor(img: torch.Tensor, saturation_factor: float) if c == 1: # Match PIL behaviour return img - return _blend(img, _FT.rgb_to_grayscale(img), saturation_factor) + return _blend(img, rgb_to_grayscale(img), saturation_factor) adjust_saturation_image_pil = _FP.adjust_saturation @@ -80,7 +80,7 @@ def adjust_contrast_image_tensor(img: torch.Tensor, contrast_factor: float) -> t raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") dtype = img.dtype if torch.is_floating_point(img) else torch.float32 if c == 3: - mean = torch.mean(_FT.rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) + mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) else: mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 2903d73ce95..cb360ecb95a 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -184,7 +184,31 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: return grayscale.repeat(repeats) -_rgb_to_gray = _FT.rgb_to_grayscale +def rgb_to_grayscale(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + if image.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {image.ndim}") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + + if num_output_channels not in (1, 3): + raise ValueError("num_output_channels should be either 1 or 3") + + if c == 3: + r, g, b = image.unbind(dim=-3) + l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114) + l_img = l_img.to(image.dtype).unsqueeze(dim=-3) + else: + l_img = image.clone() + + if num_output_channels == 3: + return l_img.expand(image.shape) + + return l_img + + +_rgb_to_gray = rgb_to_grayscale def convert_color_space_image_tensor( From 0b55072c5188b7182e9b2d8f2fe4be4cb1e62775 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 19 Oct 2022 20:34:23 +0000 Subject: [PATCH 07/11] Renamed img -> image --- .../prototype/transforms/functional/_color.py | 54 +++++++++---------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 242347dab11..24d9fd0f494 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -5,24 +5,24 @@ from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor, rgb_to_grayscale -def _blend(img1: torch.Tensor, img2: torch.Tensor, ratio: float) -> torch.Tensor: +def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: ratio = float(ratio) - fp = img1.is_floating_point() + fp = image1.is_floating_point() bound = 1.0 if fp else 255.0 - output = img1.mul(ratio).add_(img2, alpha=(1.0 - ratio)).clamp_(0, bound) - return output if fp else output.to(img1.dtype) + output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound) + return output if fp else output.to(image1.dtype) -def adjust_brightness_image_tensor(img: torch.Tensor, brightness_factor: float) -> torch.Tensor: +def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") - _FT._assert_channels(img, [1, 3]) + _FT._assert_channels(image, [1, 3]) - fp = img.is_floating_point() + fp = image.is_floating_point() bound = 1.0 if fp else 255.0 - output = img.mul(brightness_factor).clamp_(0, bound) - return output if fp else output.to(img.dtype) + output = image.mul(brightness_factor).clamp_(0, bound) + return output if fp else output.to(image.dtype) adjust_brightness_image_pil = _FP.adjust_brightness @@ -41,18 +41,18 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) -def adjust_saturation_image_tensor(img: torch.Tensor, saturation_factor: float) -> torch.Tensor: +def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") - c = get_num_channels_image_tensor(img) + c = get_num_channels_image_tensor(image) if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") if c == 1: # Match PIL behaviour - return img + return image - return _blend(img, rgb_to_grayscale(img), saturation_factor) + return _blend(image, rgb_to_grayscale(image), saturation_factor) adjust_saturation_image_pil = _FP.adjust_saturation @@ -71,20 +71,16 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) -def adjust_contrast_image_tensor(img: torch.Tensor, contrast_factor: float) -> torch.Tensor: +def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: if contrast_factor < 0: raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") - c = get_num_channels_image_tensor(img) + c = get_num_channels_image_tensor(image) if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - if c == 3: - mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) - else: - mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) - - return _blend(img, mean, contrast_factor) + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + mean = torch.mean((rgb_to_grayscale(image) if c == 3 else image).to(dtype), dim=(-3, -2, -1), keepdim=True) + return _blend(image, mean, contrast_factor) adjust_contrast_image_pil = _FP.adjust_contrast @@ -231,13 +227,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return autocontrast_image_pil(inpt) -def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor: - # input img shape should be [N, H, W] - shape = img.shape +def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor: + # input image shape should be [N, H, W] + shape = image.shape # Compute image histogram: - flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W] - hist = flat_img.new_zeros(shape[0], 256) - hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img)) + flat_image = image.flatten(start_dim=1).to(torch.long) # -> [N, H * W] + hist = flat_image.new_zeros(shape[0], 256) + hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image)) # Compute image cdf chist = hist.cumsum_(dim=1) @@ -261,7 +257,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor: zeros = lut.new_zeros((1, 1)).expand(shape[0], 1) lut = torch.cat([zeros, lut[:, :-1]], dim=1) - return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img)) + return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image)) def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: From b7fdd3903c787feb208362ee4a2eb194863e033c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 19 Oct 2022 20:39:39 +0000 Subject: [PATCH 08/11] nit code update --- torchvision/prototype/transforms/functional/_color.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 24d9fd0f494..b80eb040f5d 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -79,7 +79,8 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") dtype = image.dtype if torch.is_floating_point(image) else torch.float32 - mean = torch.mean((rgb_to_grayscale(image) if c == 3 else image).to(dtype), dim=(-3, -2, -1), keepdim=True) + grayscale_image = _FT.rgb_to_grayscale(image) if c == 3 else image + mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True) return _blend(image, mean, contrast_factor) From 41179579fa0c60a4bdc1f8c1f5883153744694b1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 20 Oct 2022 09:13:44 +0000 Subject: [PATCH 09/11] PR review --- .../prototype/transforms/functional/_color.py | 6 ++-- .../prototype/transforms/functional/_meta.py | 28 +++---------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index b80eb040f5d..ae07cc0056d 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,7 +2,7 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor, rgb_to_grayscale +from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: @@ -52,7 +52,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float if c == 1: # Match PIL behaviour return image - return _blend(image, rgb_to_grayscale(image), saturation_factor) + return _blend(image, _rgb_to_gray(image), saturation_factor) adjust_saturation_image_pil = _FP.adjust_saturation @@ -79,7 +79,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") dtype = image.dtype if torch.is_floating_point(image) else torch.float32 - grayscale_image = _FT.rgb_to_grayscale(image) if c == 3 else image + grayscale_image = _rgb_to_gray(image) if c == 3 else image mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True) return _blend(image, mean, contrast_factor) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index cb360ecb95a..61a54f01cc9 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -184,33 +184,13 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: return grayscale.repeat(repeats) -def rgb_to_grayscale(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: - if image.ndim < 3: - raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {image.ndim}") - - c = image.shape[-3] - if c not in [1, 3]: - raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") - - if num_output_channels not in (1, 3): - raise ValueError("num_output_channels should be either 1 or 3") - - if c == 3: - r, g, b = image.unbind(dim=-3) - l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114) - l_img = l_img.to(image.dtype).unsqueeze(dim=-3) - else: - l_img = image.clone() - - if num_output_channels == 3: - return l_img.expand(image.shape) - +def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor: + r, g, b = image.unbind(dim=-3) + l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114) + l_img = l_img.to(image.dtype).unsqueeze(dim=-3) return l_img -_rgb_to_gray = rgb_to_grayscale - - def convert_color_space_image_tensor( image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True ) -> torch.Tensor: From a82cf8c739d02acd9868ebee4b8b99d101c3e45e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 20 Oct 2022 10:41:46 +0000 Subject: [PATCH 10/11] adjust_contrast convert to float32 earlier --- torchvision/prototype/transforms/functional/_color.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index ae07cc0056d..65911248e5a 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -79,8 +79,8 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") dtype = image.dtype if torch.is_floating_point(image) else torch.float32 - grayscale_image = _rgb_to_gray(image) if c == 3 else image - mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True) + grayscale_image = _rgb_to_gray(image.to(dtype)) if c == 3 else image.to(dtype) + mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True) return _blend(image, mean, contrast_factor) From f19edc97ce1089ff3100006058e6b68826f75983 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 20 Oct 2022 19:59:44 +0000 Subject: [PATCH 11/11] Revert "adjust_contrast convert to float32 earlier" This reverts commit a82cf8c739d02acd9868ebee4b8b99d101c3e45e. --- torchvision/prototype/transforms/functional/_color.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 65911248e5a..ae07cc0056d 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -79,8 +79,8 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") dtype = image.dtype if torch.is_floating_point(image) else torch.float32 - grayscale_image = _rgb_to_gray(image.to(dtype)) if c == 3 else image.to(dtype) - mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True) + grayscale_image = _rgb_to_gray(image) if c == 3 else image + mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True) return _blend(image, mean, contrast_factor)