From 838ad6ccf6ef7032485671a54c1570876e366dc5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 4 Sep 2024 11:38:51 +0100 Subject: [PATCH] Allow decoding functions to accept the mode parameter as a string (#8627) --- test/test_image.py | 11 +++++++++++ torchvision/io/image.py | 25 +++++++++++++++++++------ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index c817b7b831c..624dc57fada 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1065,5 +1065,16 @@ def test_decode_image_path(input_type, scripted): decode_fun(input) +def test_mode_str(): + # Make sure decode_image supports string modes. We just test decode_image, + # not all of the decoding functions, but they should all support that too. + # Torchscript fails when passing strings, which is expected. + path = next(get_images(IMAGE_ROOT, ".png")) + assert decode_image(path, mode="RGB").shape[0] == 3 + assert decode_image(path, mode="rGb").shape[0] == 3 + assert decode_image(path, mode="GRAY").shape[0] == 1 + assert decode_image(path, mode="RGBA").shape[0] == 4 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 8805846df23..8a2281946a9 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -40,6 +40,7 @@ class ImageReadMode(Enum): GRAY_ALPHA = 2 RGB = 3 RGB_ALPHA = 4 + RGBA = RGB_ALPHA # Alias for convenience def read_file(path: str) -> torch.Tensor: @@ -92,7 +93,7 @@ def decode_png( Args: input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the PNG image. - mode (ImageReadMode): the read mode used for optionally + mode (str or ImageReadMode): the read mode used for optionally converting the image. Default: ``ImageReadMode.UNCHANGED``. See `ImageReadMode` class for more information on various available modes. @@ -104,6 +105,8 @@ def decode_png( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_png) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation) return output @@ -168,7 +171,7 @@ def decode_jpeg( input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing the raw bytes of the JPEG image. The tensor(s) must be on CPU, regardless of the ``device`` parameter. - mode (ImageReadMode): the read mode used for optionally + mode (str or ImageReadMode): the read mode used for optionally converting the image(s). The supported modes are: ``ImageReadMode.UNCHANGED``, ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB`` Default: ``ImageReadMode.UNCHANGED``. @@ -198,6 +201,8 @@ def decode_jpeg( _log_api_usage_once(decode_jpeg) if isinstance(device, str): device = torch.device(device) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] if isinstance(input, list): if len(input) == 0: @@ -298,7 +303,7 @@ def decode_image( input (Tensor or str or ``pathlib.Path``): The image to decode. If a tensor is passed, it must be one dimensional uint8 tensor containing the raw bytes of the image. Otherwise, this must be a path to the image file. - mode (ImageReadMode): the read mode used for optionally converting the image. + mode (str or ImageReadMode): the read mode used for optionally converting the image. Default: ``ImageReadMode.UNCHANGED``. See ``ImageReadMode`` class for more information on various available modes. Only applies to JPEG and PNG images. @@ -312,6 +317,8 @@ def decode_image( _log_api_usage_once(decode_image) if not isinstance(input, torch.Tensor): input = read_file(str(input)) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation) return output @@ -360,7 +367,7 @@ def decode_webp( Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the WEBP image. - mode (ImageReadMode): The read mode used for optionally + mode (str or ImageReadMode): The read mode used for optionally converting the image color space. Default: ``ImageReadMode.UNCHANGED``. Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. @@ -369,6 +376,8 @@ def decode_webp( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_webp) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] return torch.ops.image.decode_webp(input, mode.value) @@ -389,7 +398,7 @@ def _decode_avif( Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the AVIF image. - mode (ImageReadMode): The read mode used for optionally + mode (str or ImageReadMode): The read mode used for optionally converting the image color space. Default: ``ImageReadMode.UNCHANGED``. Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. @@ -398,6 +407,8 @@ def _decode_avif( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(_decode_avif) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] return torch.ops.image.decode_avif(input, mode.value) @@ -415,7 +426,7 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the HEIC image. - mode (ImageReadMode): The read mode used for optionally + mode (str or ImageReadMode): The read mode used for optionally converting the image color space. Default: ``ImageReadMode.UNCHANGED``. Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. @@ -424,4 +435,6 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(_decode_heic) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] return torch.ops.image.decode_heic(input, mode.value)