Skip to content

Commit

Permalink
Restore _split_alpha.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Oct 25, 2022
1 parent 090a0d7 commit 2286120
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions torchvision/prototype/transforms/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,12 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)


def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.tensor_split(image, indices=(-1,), dim=-3)


def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
image, alpha = _split_alpha(image)
if not torch.all(alpha == _FT._max_value(alpha.dtype)):
raise RuntimeError(
"Stripping the alpha channel if it contains values other than the max value is not supported."
Expand Down Expand Up @@ -233,7 +237,7 @@ def convert_color_space_image_tensor(
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(_strip_alpha(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
image, alpha = _split_alpha(image)
return _add_alpha(_gray_to_rgb(image), alpha)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(image)
Expand All @@ -244,7 +248,7 @@ def convert_color_space_image_tensor(
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(_strip_alpha(image))
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
image, alpha = _split_alpha(image)
return _add_alpha(_rgb_to_gray(image), alpha)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
return _strip_alpha(image)
Expand Down

0 comments on commit 2286120

Please sign in to comment.