diff --git a/setup.py b/setup.py index 7f383b82ec4..dbe8ce58aa2 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1" USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" +USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default! USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default! USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) @@ -50,6 +51,7 @@ print(f"{USE_PNG = }") print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") +print(f"{USE_HEIC = }") print(f"{USE_AVIF = }") print(f"{USE_NVJPEG = }") print(f"{NVCC_FLAGS = }") @@ -334,6 +336,21 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") + if USE_HEIC: + heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h") + if heic_found: + print("Building torchvision with HEIC support") + print(f"{heic_include_dir = }") + print(f"{heic_library_dir = }") + if heic_include_dir is not None and heic_library_dir is not None: + # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add. + include_dirs.append(heic_include_dir) + library_dirs.append(heic_library_dir) + libraries.append("heif") + define_macros += [("HEIC_FOUND", 1)] + else: + warnings.warn("Building torchvision without HEIC support") + if USE_AVIF: avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h") if avif_found: diff --git a/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic b/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic new file mode 100644 index 00000000000..4c29ac3c71c Binary files /dev/null and b/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic differ diff --git a/test/test_image.py b/test/test_image.py index f1fe70135fe..d489b10af7c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -15,6 +15,7 @@ from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence from torchvision.io.image import ( _decode_avif, + _decode_heic, decode_gif, decode_image, decode_jpeg, @@ -928,11 +929,10 @@ def test_decode_avif(decode_fun, scripted): assert img[None].is_contiguous(memory_format=torch.channels_last) -@pytest.mark.xfail(reason="AVIF support not enabled yet.") +@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.") # Note: decode_image fails because some of these files have a (valid) signature # we don't recognize. We should probably use libmagic.... -# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) -@pytest.mark.parametrize("decode_fun", (_decode_avif,)) +@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic)) @pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize( "mode, pil_mode", @@ -942,7 +942,9 @@ def test_decode_avif(decode_fun, scripted): (ImageReadMode.UNCHANGED, None), ), ) -@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif")) +@pytest.mark.parametrize( + "filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name +) def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename): if "reversed_dimg_order" in str(filename): # Pillow properly decodes this one, but we don't (order of parts of the @@ -960,7 +962,14 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename) except RuntimeError as e: if any( s in str(e) - for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image") + for s in ( + "BMFF parsing failed", + "avifDecoderParse failed: ", + "file contains more than one image", + "no 'ispe' property", + "'iref' has double references", + "Invalid image grid", + ) ): pytest.skip(reason="Expected failure, that's OK") else: @@ -970,22 +979,47 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename) assert img.shape[0] == 3 if mode == ImageReadMode.RGB_ALPHA: assert img.shape[0] == 4 + if img.dtype == torch.uint16: img = F.to_dtype(img, dtype=torch.uint8, scale=True) + try: + from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) + except RuntimeError as e: + if "Invalid image grid" in str(e): + pytest.skip(reason="PIL failure") + else: + raise e - from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) - if False: + if True: from torchvision.utils import make_grid g = make_grid([img, from_pil]) F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png")) - if mode != ImageReadMode.RGB: - # We don't compare against PIL for RGB because results look pretty - # different on RGBA images (other images are fine). The result on - # torchvision basically just plainly ignores the alpha channel, resuting - # in transparent pixels looking dark. PIL seems to be using a sort of - # k-nn thing, looking at the output. Take a look at the resuting images. - torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic" + if mode == ImageReadMode.RGB and not is__decode_heic: + # We don't compare torchvision's AVIF against PIL for RGB because + # results look pretty different on RGBA images (other images are fine). + # The result on torchvision basically just plainly ignores the alpha + # channel, resuting in transparent pixels looking dark. PIL seems to be + # using a sort of k-nn thing (Take a look at the resuting images) + return + if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic: + return + + torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + +@pytest.mark.xfail(reason="HEIC support not enabled yet.") +@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image)) +@pytest.mark.parametrize("scripted", (False, True)) +def test_decode_heic(decode_fun, scripted): + encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic"))) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img = decode_fun(encoded_bytes) + assert img.shape == (3, 100, 100) + assert img[None].is_contiguous(memory_format=torch.channels_last) if __name__ == "__main__": diff --git a/torchvision/csrc/io/image/cpu/decode_heic.cpp b/torchvision/csrc/io/image/cpu/decode_heic.cpp new file mode 100644 index 00000000000..148d6043f10 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_heic.cpp @@ -0,0 +1,152 @@ +#include "decode_heic.h" + +#if HEIC_FOUND +#include "libheif/heif_cxx.h" +#endif // HEIC_FOUND + +namespace vision { +namespace image { + +#if !HEIC_FOUND +torch::Tensor decode_heic( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + TORCH_CHECK( + false, "decode_heic: torchvision not compiled with libheif support"); +} +#else + +torch::Tensor decode_heic( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1, + "Input tensor must be 1-dimensional, got ", + encoded_data.dim(), + " dims."); + + if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && + mode != IMAGE_READ_MODE_RGB_ALPHA) { + // Other modes aren't supported, but we don't error or even warn because we + // have generic entry points like decode_image which may support all modes, + // it just depends on the underlying decoder. + mode = IMAGE_READ_MODE_UNCHANGED; + } + + // If return_rgb is false it means we return rgba - nothing else. + auto return_rgb = true; + + int height = 0; + int width = 0; + int num_channels = 0; + int stride = 0; + uint8_t* decoded_data = nullptr; + heif::Image img; + int bit_depth = 0; + + try { + heif::Context ctx; + ctx.read_from_memory_without_copy( + encoded_data.data_ptr(), encoded_data.numel()); + + // TODO properly error on (or support) image sequences. Right now, I think + // this function will always return the first image in a sequence, which is + // inconsistent with decode_gif (which returns a batch) and with decode_avif + // (which errors loudly). + // Why? I'm struggling to make sense of + // ctx.get_number_of_top_level_images(). It disagrees with libavif's + // imageCount. For example on some of the libavif test images: + // + // - colors-animated-12bpc-keyframes-0-2-3.avif + // avif num images = 5 + // heif num images = 1 // Why is this 1 when clearly this is supposed to + // be a sequence? + // - sofa_grid1x5_420.avif + // avif num images = 1 + // heif num images = 6 // If we were to error here we won't be able to + // decode this image which is otherwise properly + // decoded by libavif. + // I can't find a libheif function that does what we need here, or at least + // that agrees with libavif. + + // TORCH_CHECK( + // ctx.get_number_of_top_level_images() == 1, + // "heic file contains more than one image"); + + heif::ImageHandle handle = ctx.get_primary_image_handle(); + bit_depth = handle.get_luma_bits_per_pixel(); + + return_rgb = + (mode == IMAGE_READ_MODE_RGB || + (mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel())); + + height = handle.get_height(); + width = handle.get_width(); + + num_channels = return_rgb ? 3 : 4; + heif_chroma chroma; + if (bit_depth == 8) { + chroma = return_rgb ? heif_chroma_interleaved_RGB + : heif_chroma_interleaved_RGBA; + } else { + // TODO: This, along with our 10bits -> 16bits range mapping down below, + // may not work on BE platforms + chroma = return_rgb ? heif_chroma_interleaved_RRGGBB_LE + : heif_chroma_interleaved_RRGGBBAA_LE; + } + + img = handle.decode_image(heif_colorspace_RGB, chroma); + + decoded_data = img.get_plane(heif_channel_interleaved, &stride); + } catch (const heif::Error& err) { + // We need this try/catch block and call TORCH_CHECK, because libheif may + // otherwise throw heif::Error that would just be reported as "An unknown + // exception occurred" when we move back to Python. + TORCH_CHECK(false, "decode_heif failed: ", err.get_message()); + } + TORCH_CHECK(decoded_data != nullptr, "Something went wrong during decoding."); + + auto dtype = (bit_depth == 8) ? torch::kUInt8 : at::kUInt16; + auto out = torch::empty({height, width, num_channels}, dtype); + uint8_t* out_ptr = (uint8_t*)out.data_ptr(); + + // decoded_data is *almost* the raw decoded data, but not quite: for some + // images, there may be some padding at the end of each row, i.e. when stride + // != row_size_in_bytes. So we can't copy decoded_data into the tensor's + // memory directly, we have to copy row by row. Oh, and if you think you can + // take a shortcut when stride == row_size_in_bytes and just do: + // out = torch::from_blob(decoded_data, ...) + // you can't, because decoded_data is owned by the heif::Image object and it + // gets freed when it gets out of scope! + auto row_size_in_bytes = width * num_channels * ((bit_depth == 8) ? 1 : 2); + for (auto h = 0; h < height; h++) { + memcpy( + out_ptr + h * row_size_in_bytes, + decoded_data + h * stride, + row_size_in_bytes); + } + if (bit_depth > 8) { + // Say bit depth is 10. decodec_data and out_ptr contain 10bits values + // over 2 bytes, stored into uint16_t. In torchvision a uint16 value is + // expected to be in [0, 2**16), so we have to map the 10bits value to that + // range. Note that other libraries like libavif do that mapping + // automatically. + // TODO: It's possible to avoid the memcpy call above in this case, and do + // the copy at the same time as the conversation. Whether it's worth it + // should be benchmarked. + auto out_ptr_16 = (uint16_t*)out_ptr; + for (auto p = 0; p < height * width * num_channels; p++) { + out_ptr_16[p] <<= (16 - bit_depth); + } + } + return out.permute({2, 0, 1}); +} +#endif // HEIC_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_heic.h b/torchvision/csrc/io/image/cpu/decode_heic.h new file mode 100644 index 00000000000..4a23e4c1431 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_heic.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_heic( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index e5a421b7287..9c1a7ff3ef4 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -2,6 +2,7 @@ #include "decode_avif.h" #include "decode_gif.h" +#include "decode_heic.h" #include "decode_jpeg.h" #include "decode_png.h" #include "decode_webp.h" @@ -61,6 +62,17 @@ torch::Tensor decode_image( return decode_avif(data, mode); } + // Similarly for heic we assume the signature is "ftypeheic" but some files + // may come as "ftypmif1" where the "heic" part is defined later in the file. + // We can't be re-inventing libmagic here. We might need to start relying on + // it though... + const uint8_t heic_signature[8] = { + 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63}; // == "ftypheic" + TORCH_CHECK(data.numel() >= 12, err_msg); + if ((memcmp(heic_signature, datap + 4, 8) == 0)) { + return decode_heic(data, mode); + } + const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" const uint8_t webp_signature_end[7] = { 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index a777d19d3bd..f0ce91144a6 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -23,6 +23,8 @@ static auto registry = &decode_jpeg) .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", &decode_webp) + .op("image::decode_heic(Tensor encoded_data, int mode) -> Tensor", + &decode_heic) .op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor", &decode_avif) .op("image::encode_jpeg", &encode_jpeg) diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 91a5144fa1c..23493f3c030 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -2,6 +2,7 @@ #include "cpu/decode_avif.h" #include "cpu/decode_gif.h" +#include "cpu/decode_heic.h" #include "cpu/decode_image.h" #include "cpu/decode_jpeg.h" #include "cpu/decode_png.h" diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 08a0d6d62b7..a604ea1fdb6 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -61,6 +61,7 @@ "decode_image", "decode_jpeg", "decode_png", + "decode_heic", "decode_webp", "decode_gif", "encode_jpeg", diff --git a/torchvision/io/image.py b/torchvision/io/image.py index e169c0a4f7a..f1df0d52672 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -417,5 +417,31 @@ def _decode_avif( Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(decode_webp) + _log_api_usage_once(_decode_avif) return torch.ops.image.decode_avif(input, mode.value) + + +def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: + """ + Decode an HEIC image into a 3 dimensional RGB[A] Tensor. + + The values of the output tensor are in uint8 in [0, 255] for most images. If + the image has a bit-depth of more than 8, then the output tensor is uint16 + in [0, 65535]. Since uint16 support is limited in pytorch, we recommend + calling :func:`torchvision.transforms.v2.functional.to_dtype()` with + ``scale=True`` after this function to convert the decoded image into a uint8 + or float tensor. + + Args: + input (Tensor[1]): a one dimensional contiguous uint8 tensor containing + the raw bytes of the HEIC image. + mode (ImageReadMode): The read mode used for optionally + converting the image color space. Default: ``ImageReadMode.UNCHANGED``. + Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + + Returns: + Decoded image (Tensor[image_channels, image_height, image_width]) + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(_decode_heic) + return torch.ops.image.decode_heic(input, mode.value)