From 3f2c2dffb2931fa427688026c6052d8089d6756a Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 28 Aug 2024 11:34:36 +0000 Subject: [PATCH] 2024-08-28 nightly release (a59c93980d97f6216917415ae25f3ac88e64cbb4) --- test/test_image.py | 87 ++++++++++++++++++- torchvision/csrc/io/image/cpu/decode_avif.cpp | 43 ++++++--- torchvision/csrc/io/image/cpu/decode_avif.h | 5 +- .../csrc/io/image/cpu/decode_image.cpp | 4 +- torchvision/csrc/io/image/cpu/decode_webp.cpp | 48 ++++++++-- torchvision/csrc/io/image/cpu/decode_webp.h | 5 +- torchvision/csrc/io/image/image.cpp | 6 +- torchvision/io/image.py | 39 +++++++-- .../transforms/v2/functional/_color.py | 2 +- 9 files changed, 209 insertions(+), 30 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 5b0da3481ab..f1fe70135fe 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -875,7 +875,7 @@ def test_decode_gif_webp_errors(decode_fun): if decode_fun is decode_gif: expected_match = re.escape("DGifOpenFileName() failed - 103") elif decode_fun is decode_webp: - expected_match = "WebPDecodeRGB failed." + expected_match = "WebPGetFeatures failed." with pytest.raises(RuntimeError, match=expected_match): decode_fun(encoded_data) @@ -891,6 +891,31 @@ def test_decode_webp(decode_fun, scripted): assert img[None].is_contiguous(memory_format=torch.channels_last) +# This test is skipped because it requires webp images that we're not including +# within the repo. The test images were downloaded from the different pages of +# https://developers.google.com/speed/webp/gallery +# Note that converting an RGBA image to RGB leads to bad results because the +# transparent pixels aren't necessarily set to "black" or "white", they can be +# random stuff. This is consistent with PIL results. +@pytest.mark.skip(reason="Need to download test images first") +@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image)) +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize( + "mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None)) +) +@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp")) +def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename): + encoded_bytes = read_file(filename) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img = decode_fun(encoded_bytes, mode=mode) + assert img[None].is_contiguous(memory_format=torch.channels_last) + + pil_img = Image.open(filename).convert(pil_mode) + from_pil = F.pil_to_tensor(pil_img) + assert_equal(img, from_pil) + + @pytest.mark.xfail(reason="AVIF support not enabled yet.") @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) @pytest.mark.parametrize("scripted", (False, True)) @@ -903,5 +928,65 @@ def test_decode_avif(decode_fun, scripted): assert img[None].is_contiguous(memory_format=torch.channels_last) +@pytest.mark.xfail(reason="AVIF support not enabled yet.") +# Note: decode_image fails because some of these files have a (valid) signature +# we don't recognize. We should probably use libmagic.... +# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) +@pytest.mark.parametrize("decode_fun", (_decode_avif,)) +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize( + "mode, pil_mode", + ( + (ImageReadMode.RGB, "RGB"), + (ImageReadMode.RGB_ALPHA, "RGBA"), + (ImageReadMode.UNCHANGED, None), + ), +) +@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif")) +def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename): + if "reversed_dimg_order" in str(filename): + # Pillow properly decodes this one, but we don't (order of parts of the + # image is wrong). This is due to a bug that was recently fixed in + # libavif. Hopefully this test will end up passing soon with a new + # libavif version https://github.com/AOMediaCodec/libavif/issues/2311 + pytest.xfail() + import pillow_avif # noqa + + encoded_bytes = read_file(filename) + if scripted: + decode_fun = torch.jit.script(decode_fun) + try: + img = decode_fun(encoded_bytes, mode=mode) + except RuntimeError as e: + if any( + s in str(e) + for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image") + ): + pytest.skip(reason="Expected failure, that's OK") + else: + raise e + assert img[None].is_contiguous(memory_format=torch.channels_last) + if mode == ImageReadMode.RGB: + assert img.shape[0] == 3 + if mode == ImageReadMode.RGB_ALPHA: + assert img.shape[0] == 4 + if img.dtype == torch.uint16: + img = F.to_dtype(img, dtype=torch.uint8, scale=True) + + from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) + if False: + from torchvision.utils import make_grid + + g = make_grid([img, from_pil]) + F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png")) + if mode != ImageReadMode.RGB: + # We don't compare against PIL for RGB because results look pretty + # different on RGBA images (other images are fine). The result on + # torchvision basically just plainly ignores the alpha channel, resuting + # in transparent pixels looking dark. PIL seems to be using a sort of + # k-nn thing, looking at the output. Take a look at the resuting images. + torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/csrc/io/image/cpu/decode_avif.cpp b/torchvision/csrc/io/image/cpu/decode_avif.cpp index ec136743806..5752f04a448 100644 --- a/torchvision/csrc/io/image/cpu/decode_avif.cpp +++ b/torchvision/csrc/io/image/cpu/decode_avif.cpp @@ -8,7 +8,9 @@ namespace vision { namespace image { #if !AVIF_FOUND -torch::Tensor decode_avif(const torch::Tensor& data) { +torch::Tensor decode_avif( + const torch::Tensor& encoded_data, + ImageReadMode mode) { TORCH_CHECK( false, "decode_avif: torchvision not compiled with libavif support"); } @@ -23,7 +25,9 @@ struct UniquePtrDeleter { }; using DecoderPtr = std::unique_ptr; -torch::Tensor decode_avif(const torch::Tensor& encoded_data) { +torch::Tensor decode_avif( + const torch::Tensor& encoded_data, + ImageReadMode mode) { // This is based on // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c // Refer there for more detail about what each function does, and which @@ -58,9 +62,6 @@ torch::Tensor decode_avif(const torch::Tensor& encoded_data) { avifResultToString(result)); TORCH_CHECK( decoder->imageCount == 1, "Avif file contains more than one image"); - TORCH_CHECK( - decoder->image->depth <= 8, - "avif images with bitdepth > 8 are not supported"); result = avifDecoderNextImage(decoder.get()); TORCH_CHECK( @@ -68,14 +69,36 @@ torch::Tensor decode_avif(const torch::Tensor& encoded_data) { "avifDecoderNextImage failed:", avifResultToString(result)); - auto out = torch::empty( - {decoder->image->height, decoder->image->width, 3}, torch::kUInt8); - avifRGBImage rgb; memset(&rgb, 0, sizeof(rgb)); avifRGBImageSetDefaults(&rgb, decoder->image); - rgb.format = AVIF_RGB_FORMAT_RGB; - rgb.pixels = out.data_ptr(); + + // images encoded as 10 or 12 bits will be decoded as uint16. The rest are + // decoded as uint8. + auto use_uint8 = (decoder->image->depth <= 8); + rgb.depth = use_uint8 ? 8 : 16; + + if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && + mode != IMAGE_READ_MODE_RGB_ALPHA) { + // Other modes aren't supported, but we don't error or even warn because we + // have generic entry points like decode_image which may support all modes, + // it just depends on the underlying decoder. + mode = IMAGE_READ_MODE_UNCHANGED; + } + + // If return_rgb is false it means we return rgba - nothing else. + auto return_rgb = + (mode == IMAGE_READ_MODE_RGB || + (mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent)); + + auto num_channels = return_rgb ? 3 : 4; + rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA; + rgb.ignoreAlpha = return_rgb ? AVIF_TRUE : AVIF_FALSE; + + auto out = torch::empty( + {rgb.height, rgb.width, num_channels}, + use_uint8 ? torch::kUInt8 : at::kUInt16); + rgb.pixels = (uint8_t*)out.data_ptr(); rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb); result = avifImageYUVToRGB(decoder->image, &rgb); diff --git a/torchvision/csrc/io/image/cpu/decode_avif.h b/torchvision/csrc/io/image/cpu/decode_avif.h index 269bce52197..0510c2104e5 100644 --- a/torchvision/csrc/io/image/cpu/decode_avif.h +++ b/torchvision/csrc/io/image/cpu/decode_avif.h @@ -1,11 +1,14 @@ #pragma once #include +#include "../image_read_mode.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data); +C10_EXPORT torch::Tensor decode_avif( + const torch::Tensor& encoded_data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index 75c7e06195a..e5a421b7287 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -58,7 +58,7 @@ torch::Tensor decode_image( 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif" TORCH_CHECK(data.numel() >= 12, err_msg); if ((memcmp(avif_signature, datap + 4, 8) == 0)) { - return decode_avif(data); + return decode_avif(data, mode); } const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" @@ -67,7 +67,7 @@ torch::Tensor decode_image( TORCH_CHECK(data.numel() >= 15, err_msg); if ((memcmp(webp_signature_begin, datap, 4) == 0) && (memcmp(webp_signature_end, datap + 8, 7) == 0)) { - return decode_webp(data); + return decode_webp(data, mode); } TORCH_CHECK(false, err_msg); diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp index 844ce61a3e3..bf115c23c41 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.cpp +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -8,13 +8,17 @@ namespace vision { namespace image { #if !WEBP_FOUND -torch::Tensor decode_webp(const torch::Tensor& data) { +torch::Tensor decode_webp( + const torch::Tensor& encoded_data, + ImageReadMode mode) { TORCH_CHECK( false, "decode_webp: torchvision not compiled with libwebp support"); } #else -torch::Tensor decode_webp(const torch::Tensor& encoded_data) { +torch::Tensor decode_webp( + const torch::Tensor& encoded_data, + ImageReadMode mode) { TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); TORCH_CHECK( encoded_data.dtype() == torch::kU8, @@ -26,13 +30,43 @@ torch::Tensor decode_webp(const torch::Tensor& encoded_data) { encoded_data.dim(), " dims."); + auto encoded_data_p = encoded_data.data_ptr(); + auto encoded_data_size = encoded_data.numel(); + + WebPBitstreamFeatures features; + auto res = WebPGetFeatures(encoded_data_p, encoded_data_size, &features); + TORCH_CHECK( + res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res); + TORCH_CHECK( + !features.has_animation, "Animated webp files are not supported."); + + if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && + mode != IMAGE_READ_MODE_RGB_ALPHA) { + // Other modes aren't supported, but we don't error or even warn because we + // have generic entry points like decode_image which may support all modes, + // it just depends on the underlying decoder. + mode = IMAGE_READ_MODE_UNCHANGED; + } + + // If return_rgb is false it means we return rgba - nothing else. + auto return_rgb = + (mode == IMAGE_READ_MODE_RGB || + (mode == IMAGE_READ_MODE_UNCHANGED && !features.has_alpha)); + + auto decoding_func = return_rgb ? WebPDecodeRGB : WebPDecodeRGBA; + auto num_channels = return_rgb ? 3 : 4; + int width = 0; int height = 0; - auto decoded_data = WebPDecodeRGB( - encoded_data.data_ptr(), encoded_data.numel(), &width, &height); - TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed."); - auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8); - return out.permute({2, 0, 1}); // return CHW, channels-last + + auto decoded_data = + decoding_func(encoded_data_p, encoded_data_size, &width, &height); + TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed."); + + auto out = torch::from_blob( + decoded_data, {height, width, num_channels}, torch::kUInt8); + + return out.permute({2, 0, 1}); } #endif // WEBP_FOUND diff --git a/torchvision/csrc/io/image/cpu/decode_webp.h b/torchvision/csrc/io/image/cpu/decode_webp.h index 00a0c3362f7..5632ea56ff9 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.h +++ b/torchvision/csrc/io/image/cpu/decode_webp.h @@ -1,11 +1,14 @@ #pragma once #include +#include "../image_read_mode.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data); +C10_EXPORT torch::Tensor decode_webp( + const torch::Tensor& encoded_data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 43e8ecbe4a2..a777d19d3bd 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -21,8 +21,10 @@ static auto registry = .op("image::encode_png", &encode_png) .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_jpeg) - .op("image::decode_webp", &decode_webp) - .op("image::decode_avif", &decode_avif) + .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", + &decode_webp) + .op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor", + &decode_avif) .op("image::encode_jpeg", &encode_jpeg) .op("image::read_file", &read_file) .op("image::write_file", &write_file) diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 6d4613f703b..e169c0a4f7a 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -28,6 +28,11 @@ class ImageReadMode(Enum): ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency, ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for RGB with transparency. + + .. note:: + + Some decoders won't support all possible values, e.g. a decoder may only + support "RGB" and "RGBA" mode. """ UNCHANGED = 0 @@ -365,28 +370,52 @@ def decode_gif(input: torch.Tensor) -> torch.Tensor: def decode_webp( input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, ) -> torch.Tensor: """ - Decode a WEBP image into a 3 dimensional RGB Tensor. + Decode a WEBP image into a 3 dimensional RGB[A] Tensor. - The values of the output tensor are uint8 between 0 and 255. If the input - image is RGBA, the transparency is ignored. + The values of the output tensor are uint8 between 0 and 255. Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the WEBP image. + mode (ImageReadMode): The read mode used for optionally + converting the image color space. Default: ``ImageReadMode.UNCHANGED``. + Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. Returns: Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_webp) - return torch.ops.image.decode_webp(input) + return torch.ops.image.decode_webp(input, mode.value) def _decode_avif( input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, ) -> torch.Tensor: + """ + Decode an AVIF image into a 3 dimensional RGB[A] Tensor. + + The values of the output tensor are in uint8 in [0, 255] for most images. If + the image has a bit-depth of more than 8, then the output tensor is uint16 + in [0, 65535]. Since uint16 support is limited in pytorch, we recommend + calling :func:`torchvision.transforms.v2.functional.to_dtype()` with + ``scale=True`` after this function to convert the decoded image into a uint8 + or float tensor. + + Args: + input (Tensor[1]): a one dimensional contiguous uint8 tensor containing + the raw bytes of the AVIF image. + mode (ImageReadMode): The read mode used for optionally + converting the image color space. Default: ``ImageReadMode.UNCHANGED``. + Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + + Returns: + Decoded image (Tensor[image_channels, image_height, image_width]) + """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_webp) - return torch.ops.image.decode_avif(input) + return torch.ops.image.decode_avif(input, mode.value) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 34d1e101dbd..eb75f58cb7a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -66,7 +66,7 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor: - """See :class:`~torchvision.transforms.v2.GrayscaleToRgb` for details.""" + """See :class:`~torchvision.transforms.v2.RGB` for details.""" if torch.jit.is_scripting(): return grayscale_to_rgb_image(inpt)