From 2286120be6d4af2a3c9b52b605d87611ec70fe06 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 25 Oct 2022 12:20:15 +0100 Subject: [PATCH] Restore `_split_alpha`. --- torchvision/prototype/transforms/functional/_meta.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 5e017848415..98f1963c1ea 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -183,8 +183,12 @@ def clamp_bounding_box( return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format) +def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.tensor_split(image, indices=(-1,), dim=-3) + + def _strip_alpha(image: torch.Tensor) -> torch.Tensor: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) + image, alpha = _split_alpha(image) if not torch.all(alpha == _FT._max_value(alpha.dtype)): raise RuntimeError( "Stripping the alpha channel if it contains values other than the max value is not supported." @@ -233,7 +237,7 @@ def convert_color_space_image_tensor( elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB: return _gray_to_rgb(_strip_alpha(image)) elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) + image, alpha = _split_alpha(image) return _add_alpha(_gray_to_rgb(image), alpha) elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY: return _rgb_to_gray(image) @@ -244,7 +248,7 @@ def convert_color_space_image_tensor( elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY: return _rgb_to_gray(_strip_alpha(image)) elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA: - image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) + image, alpha = _split_alpha(image) return _add_alpha(_rgb_to_gray(image), alpha) elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB: return _strip_alpha(image)