Skip to content

Commit

Permalink
Allow decoding functions to accept the mode parameter as a string (#8627)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Sep 4, 2024
1 parent d0ebeb5 commit 838ad6c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
11 changes: 11 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,5 +1065,16 @@ def test_decode_image_path(input_type, scripted):
decode_fun(input)


def test_mode_str():
    """Check that ``decode_image`` accepts the read mode as a plain string.

    Only ``decode_image`` is exercised here rather than every decoding
    function, but they should all support string modes too. Torchscript
    fails when passing strings, which is expected.
    """
    png_path = next(get_images(IMAGE_ROOT, ".png"))

    # Mode string -> expected size of the first output dimension.
    # "rGb" checks that a mixed-case spelling is accepted as well.
    expected_sizes = (
        ("RGB", 3),
        ("rGb", 3),
        ("GRAY", 1),
        ("RGBA", 4),
    )
    for mode_name, expected_size in expected_sizes:
        assert decode_image(png_path, mode=mode_name).shape[0] == expected_size


# When this module is executed directly as a script, hand this file to pytest.
if __name__ == "__main__":
pytest.main([__file__])
25 changes: 19 additions & 6 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ImageReadMode(Enum):
GRAY_ALPHA = 2
RGB = 3
RGB_ALPHA = 4
RGBA = RGB_ALPHA # Alias for convenience


def read_file(path: str) -> torch.Tensor:
Expand Down Expand Up @@ -92,7 +93,7 @@ def decode_png(
Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the PNG image.
mode (ImageReadMode): the read mode used for optionally
mode (str or ImageReadMode): the read mode used for optionally
converting the image. Default: ``ImageReadMode.UNCHANGED``.
See `ImageReadMode` class for more information on various
available modes.
Expand All @@ -104,6 +105,8 @@ def decode_png(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_png)
if isinstance(mode, str):
mode = ImageReadMode[mode.upper()]
output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
return output

Expand Down Expand Up @@ -168,7 +171,7 @@ def decode_jpeg(
input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
the raw bytes of the JPEG image. The tensor(s) must be on CPU,
regardless of the ``device`` parameter.
mode (ImageReadMode): the read mode used for optionally
mode (str or ImageReadMode): the read mode used for optionally
converting the image(s). The supported modes are: ``ImageReadMode.UNCHANGED``,
``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``
Default: ``ImageReadMode.UNCHANGED``.
Expand Down Expand Up @@ -198,6 +201,8 @@ def decode_jpeg(
_log_api_usage_once(decode_jpeg)
if isinstance(device, str):
device = torch.device(device)
if isinstance(mode, str):
mode = ImageReadMode[mode.upper()]

if isinstance(input, list):
if len(input) == 0:
Expand Down Expand Up @@ -298,7 +303,7 @@ def decode_image(
input (Tensor or str or ``pathlib.Path``): The image to decode. If a
tensor is passed, it must be one dimensional uint8 tensor containing
the raw bytes of the image. Otherwise, this must be a path to the image file.
mode (ImageReadMode): the read mode used for optionally converting the image.
mode (str or ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes. Only applies to JPEG and PNG images.
Expand All @@ -312,6 +317,8 @@ def decode_image(
_log_api_usage_once(decode_image)
if not isinstance(input, torch.Tensor):
input = read_file(str(input))
if isinstance(mode, str):
mode = ImageReadMode[mode.upper()]
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
return output

Expand Down Expand Up @@ -360,7 +367,7 @@ def decode_webp(
Args:
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
the raw bytes of the WEBP image.
mode (ImageReadMode): The read mode used for optionally
mode (str or ImageReadMode): The read mode used for optionally
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
Expand All @@ -369,6 +376,8 @@ def decode_webp(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_webp)
if isinstance(mode, str):
mode = ImageReadMode[mode.upper()]
return torch.ops.image.decode_webp(input, mode.value)


Expand All @@ -389,7 +398,7 @@ def _decode_avif(
Args:
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
the raw bytes of the AVIF image.
mode (ImageReadMode): The read mode used for optionally
mode (str or ImageReadMode): The read mode used for optionally
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
Expand All @@ -398,6 +407,8 @@ def _decode_avif(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(_decode_avif)
if isinstance(mode, str):
mode = ImageReadMode[mode.upper()]
return torch.ops.image.decode_avif(input, mode.value)


Expand All @@ -415,7 +426,7 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
Args:
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
the raw bytes of the HEIC image.
mode (ImageReadMode): The read mode used for optionally
mode (str or ImageReadMode): The read mode used for optionally
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
Expand All @@ -424,4 +435,6 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(_decode_heic)
if isinstance(mode, str):
mode = ImageReadMode[mode.upper()]
return torch.ops.image.decode_heic(input, mode.value)

0 comments on commit 838ad6c

Please sign in to comment.