diff --git a/CMakeLists.txt b/CMakeLists.txt index cf305c4ec17..f2430559909 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,8 @@ option(WITH_JPEG "Enable features requiring LibJPEG." ON) # untested. Since building from cmake is very low pri anyway, this is OK. If # you're a user and you need this, please open an issue (and a PR!). option(WITH_WEBP "Enable features requiring LibWEBP." OFF) +# Same here +option(WITH_AVIF "Enable features requiring LibAVIF." OFF) if(WITH_CUDA) enable_language(CUDA) @@ -41,6 +43,11 @@ if (WITH_WEBP) find_package(WEBP REQUIRED) endif() +if (WITH_AVIF) + add_definitions(-DAVIF_FOUND) + find_package(AVIF REQUIRED) +endif() + function(CUDA_CONVERT_FLAGS EXISTING_TARGET) get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS) if(NOT "${old_flags}" STREQUAL "") @@ -117,6 +124,10 @@ if (WITH_WEBP) target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES}) endif() +if (WITH_AVIF) + target_link_libraries(${PROJECT_NAME} PRIVATE ${AVIF_LIBRARIES}) +endif() + set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib) @@ -135,6 +146,10 @@ if (WITH_WEBP) include_directories(${WEBP_INCLUDE_DIRS}) endif() +if (WITH_AVIF) + include_directories(${AVIF_INCLUDE_DIRS}) +endif() + set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake") configure_package_config_file(cmake/TorchVisionConfig.cmake.in diff --git a/setup.py b/setup.py index fb3b503e6e6..7f383b82ec4 100644 --- a/setup.py +++ b/setup.py @@ -19,10 +19,17 @@ USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1" USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" +USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default! USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) -USE_FFMPEG = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1" -USE_VIDEO_CODEC = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1" +# Note: the GPU video decoding stuff used to be called "video codec", which +# isn't an accurate or descriptive name considering there are at least 2 other +# video deocding backends in torchvision. I'm renaming this to "gpu video +# decoder" where possible, keeping user facing names (like the env var below) to +# the old scheme for BC. +USE_GPU_VIDEO_DECODER = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1" +# Same here: "use ffmpeg" was used to denote "use cpu video decoder". +USE_CPU_VIDEO_DECODER = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1" TORCHVISION_INCLUDE = os.environ.get("TORCHVISION_INCLUDE", "") TORCHVISION_LIBRARY = os.environ.get("TORCHVISION_LIBRARY", "") @@ -43,10 +50,11 @@ print(f"{USE_PNG = }") print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") +print(f"{USE_AVIF = }") print(f"{USE_NVJPEG = }") print(f"{NVCC_FLAGS = }") -print(f"{USE_FFMPEG = }") -print(f"{USE_VIDEO_CODEC = }") +print(f"{USE_CPU_VIDEO_DECODER = }") +print(f"{USE_GPU_VIDEO_DECODER = }") print(f"{TORCHVISION_INCLUDE = }") print(f"{TORCHVISION_LIBRARY = }") print(f"{IS_ROCM = }") @@ -326,6 +334,21 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") + if USE_AVIF: + avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h") + if avif_found: + print("Building torchvision with AVIF support") + print(f"{avif_include_dir = }") + print(f"{avif_library_dir = }") + if avif_include_dir is not None and avif_library_dir is not None: + # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add. + include_dirs.append(avif_include_dir) + library_dirs.append(avif_library_dir) + libraries.append("avif") + define_macros += [("AVIF_FOUND", 1)] + else: + warnings.warn("Building torchvision without AVIF support") + if USE_NVJPEG and torch.cuda.is_available(): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() @@ -351,28 +374,21 @@ def make_image_extension(): def make_video_decoders_extensions(): print("Building video decoder extensions") - # Locating ffmpeg - ffmpeg_exe = shutil.which("ffmpeg") - has_ffmpeg = ffmpeg_exe is not None - ffmpeg_version = None - # FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9 - # FIXME: causes crash. See the following GitHub issues for more details. - # FIXME: https://github.com/pytorch/pytorch/issues/65000 - # FIXME: https://github.com/pytorch/vision/issues/3367 + build_without_extensions_msg = "Building without video decoders extensions." if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9): - has_ffmpeg = False - if has_ffmpeg: - try: - # This is to check if ffmpeg is installed properly. - ffmpeg_version = subprocess.check_output(["ffmpeg", "-version"]) - except subprocess.CalledProcessError: - print("Building torchvision without ffmpeg support") - print(" Error fetching ffmpeg version, ignoring ffmpeg.") - has_ffmpeg = False + # FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9 + # FIXME: causes crash. See the following GitHub issues for more details. + # FIXME: https://github.com/pytorch/pytorch/issues/65000 + # FIXME: https://github.com/pytorch/vision/issues/3367 + print("Can only build video decoder extensions on linux and Python != 3.9") + return [] - use_ffmpeg = USE_FFMPEG and has_ffmpeg + ffmpeg_exe = shutil.which("ffmpeg") + if ffmpeg_exe is None: + print(f"{build_without_extensions_msg} Couldn't find ffmpeg binary.") + return [] - if use_ffmpeg: + def find_ffmpeg_libraries(): ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"} ffmpeg_bin = os.path.dirname(ffmpeg_exe) @@ -399,18 +415,23 @@ def make_video_decoders_extensions(): library_found |= len(glob.glob(full_path)) > 0 if not library_found: - print("Building torchvision without ffmpeg support") - print(f" {library} header files were not found, disabling ffmpeg support") - use_ffmpeg = False - else: - print("Building torchvision without ffmpeg support") + print(f"{build_without_extensions_msg}") + print(f"{library} header files were not found.") + return None, None + + return ffmpeg_include_dir, ffmpeg_library_dir + + ffmpeg_include_dir, ffmpeg_library_dir = find_ffmpeg_libraries() + if ffmpeg_include_dir is None or ffmpeg_library_dir is None: + return [] + + print("Found ffmpeg:") + print(f" ffmpeg include path: {ffmpeg_include_dir}") + print(f" ffmpeg library_dir: {ffmpeg_library_dir}") extensions = [] - if use_ffmpeg: - print("Building torchvision with ffmpeg support") - print(f" ffmpeg version: {ffmpeg_version}") - print(f" ffmpeg include path: {ffmpeg_include_dir}") - print(f" ffmpeg library_dir: {ffmpeg_library_dir}") + if USE_CPU_VIDEO_DECODER: + print("Building with CPU video decoder support") # TorchVision base decoder + video reader video_reader_src_dir = os.path.join(ROOT_DIR, "torchvision", "csrc", "io", "video_reader") @@ -427,6 +448,7 @@ def make_video_decoders_extensions(): extensions.append( CppExtension( + # This is an aweful name. It should be "cpu_video_decoder". Keeping for BC. "torchvision.video_reader", combined_src, include_dirs=[ @@ -450,25 +472,24 @@ def make_video_decoders_extensions(): ) ) - # Locating video codec - # CUDA_HOME should be set to the cuda root directory. - # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to - # video codec header files and libraries respectively. - video_codec_found = ( - BUILD_CUDA_SOURCES - and CUDA_HOME is not None - and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in TORCHVISION_INCLUDE]) - and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in TORCHVISION_INCLUDE]) - and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in TORCHVISION_LIBRARY]) - ) + if USE_GPU_VIDEO_DECODER: + # Locating GPU video decoder headers and libraries + # CUDA_HOME should be set to the cuda root directory. + # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the locations + # to the headers and libraries below + if not ( + BUILD_CUDA_SOURCES + and CUDA_HOME is not None + and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in TORCHVISION_INCLUDE]) + and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in TORCHVISION_INCLUDE]) + and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in TORCHVISION_LIBRARY]) + and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir]) + ): + print("Could not find necessary dependencies. Refer the setup.py to check which ones are needed.") + print("Building without GPU video decoder support") + return extensions + print("Building torchvision with GPU video decoder support") - use_video_codec = USE_VIDEO_CODEC and video_codec_found - if ( - use_video_codec - and use_ffmpeg - and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir]) - ): - print("Building torchvision with video codec support") gpu_decoder_path = os.path.join(CSRS_DIR, "io", "decoder", "gpu") gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) cuda_libs = os.path.join(CUDA_HOME, "lib64") @@ -477,7 +498,7 @@ def make_video_decoders_extensions(): _, extra_compile_args = get_macros_and_flags() extensions.append( CUDAExtension( - "torchvision.Decoder", + "torchvision.gpu_decoder", gpu_decoder_src, include_dirs=[CSRS_DIR] + TORCHVISION_INCLUDE + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir, library_dirs=ffmpeg_library_dir + TORCHVISION_LIBRARY + [cuda_libs], @@ -498,18 +519,6 @@ def make_video_decoders_extensions(): extra_compile_args=extra_compile_args, ) ) - else: - print("Building torchvision without video codec support") - if ( - use_video_codec - and use_ffmpeg - and not any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir]) - ): - print( - " The installed version of ffmpeg is missing the header file 'bsf.h' which is " - " required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:" - " `conda install -c conda-forge ffmpeg`." - ) return extensions diff --git a/test/assets/fakedata/logos/rgb_pytorch.avif b/test/assets/fakedata/logos/rgb_pytorch.avif new file mode 100644 index 00000000000..ea1bb586957 Binary files /dev/null and b/test/assets/fakedata/logos/rgb_pytorch.avif differ diff --git a/test/test_image.py b/test/test_image.py index cce7d6e0ff7..f1fe70135fe 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -14,6 +14,7 @@ from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence from torchvision.io.image import ( + _decode_avif, decode_gif, decode_image, decode_jpeg, @@ -873,8 +874,8 @@ def test_decode_gif_webp_errors(decode_fun): decode_fun(encoded_data[::2]) if decode_fun is decode_gif: expected_match = re.escape("DGifOpenFileName() failed - 103") - else: - expected_match = "WebPDecodeRGB failed." + elif decode_fun is decode_webp: + expected_match = "WebPGetFeatures failed." with pytest.raises(RuntimeError, match=expected_match): decode_fun(encoded_data) @@ -890,5 +891,102 @@ def test_decode_webp(decode_fun, scripted): assert img[None].is_contiguous(memory_format=torch.channels_last) +# This test is skipped because it requires webp images that we're not including +# within the repo. The test images were downloaded from the different pages of +# https://developers.google.com/speed/webp/gallery +# Note that converting an RGBA image to RGB leads to bad results because the +# transparent pixels aren't necessarily set to "black" or "white", they can be +# random stuff. This is consistent with PIL results. +@pytest.mark.skip(reason="Need to download test images first") +@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image)) +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize( + "mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None)) +) +@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp")) +def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename): + encoded_bytes = read_file(filename) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img = decode_fun(encoded_bytes, mode=mode) + assert img[None].is_contiguous(memory_format=torch.channels_last) + + pil_img = Image.open(filename).convert(pil_mode) + from_pil = F.pil_to_tensor(pil_img) + assert_equal(img, from_pil) + + +@pytest.mark.xfail(reason="AVIF support not enabled yet.") +@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) +@pytest.mark.parametrize("scripted", (False, True)) +def test_decode_avif(decode_fun, scripted): + encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".avif"))) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img = decode_fun(encoded_bytes) + assert img.shape == (3, 100, 100) + assert img[None].is_contiguous(memory_format=torch.channels_last) + + +@pytest.mark.xfail(reason="AVIF support not enabled yet.") +# Note: decode_image fails because some of these files have a (valid) signature +# we don't recognize. We should probably use libmagic.... +# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) +@pytest.mark.parametrize("decode_fun", (_decode_avif,)) +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize( + "mode, pil_mode", + ( + (ImageReadMode.RGB, "RGB"), + (ImageReadMode.RGB_ALPHA, "RGBA"), + (ImageReadMode.UNCHANGED, None), + ), +) +@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif")) +def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename): + if "reversed_dimg_order" in str(filename): + # Pillow properly decodes this one, but we don't (order of parts of the + # image is wrong). This is due to a bug that was recently fixed in + # libavif. Hopefully this test will end up passing soon with a new + # libavif version https://github.com/AOMediaCodec/libavif/issues/2311 + pytest.xfail() + import pillow_avif # noqa + + encoded_bytes = read_file(filename) + if scripted: + decode_fun = torch.jit.script(decode_fun) + try: + img = decode_fun(encoded_bytes, mode=mode) + except RuntimeError as e: + if any( + s in str(e) + for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image") + ): + pytest.skip(reason="Expected failure, that's OK") + else: + raise e + assert img[None].is_contiguous(memory_format=torch.channels_last) + if mode == ImageReadMode.RGB: + assert img.shape[0] == 3 + if mode == ImageReadMode.RGB_ALPHA: + assert img.shape[0] == 4 + if img.dtype == torch.uint16: + img = F.to_dtype(img, dtype=torch.uint8, scale=True) + + from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) + if False: + from torchvision.utils import make_grid + + g = make_grid([img, from_pil]) + F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png")) + if mode != ImageReadMode.RGB: + # We don't compare against PIL for RGB because results look pretty + # different on RGBA images (other images are fine). The result on + # torchvision basically just plainly ignores the alpha channel, resuting + # in transparent pixels looking dark. PIL seems to be using a sort of + # k-nn thing, looking at the output. Take a look at the resuting images. + torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_io.py b/test/test_io.py index 1b7b7eb15a1..d2950ac9595 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -63,7 +63,7 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, @pytest.mark.skipif( - get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, reason="video_reader backend not available" + get_video_backend() != "pyav" and not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend not available" ) @pytest.mark.skipif(av is None, reason="PyAV unavailable") class TestVideo: @@ -77,14 +77,14 @@ def test_write_read_video(self): assert_equal(data, lv) assert info["video_fps"] == 5 - @pytest.mark.skipif(not io._HAS_VIDEO_OPT, reason="video_reader backend is not chosen") + @pytest.mark.skipif(not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend is not chosen") def test_probe_video_from_file(self): with temp_video(10, 300, 300, 5) as (f_name, data): video_info = io._probe_video_from_file(f_name) assert pytest.approx(2, rel=0.0, abs=0.1) == video_info.video_duration assert pytest.approx(5, rel=0.0, abs=0.1) == video_info.video_fps - @pytest.mark.skipif(not io._HAS_VIDEO_OPT, reason="video_reader backend is not chosen") + @pytest.mark.skipif(not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend is not chosen") def test_probe_video_from_memory(self): with temp_video(10, 300, 300, 5) as (f_name, data): with open(f_name, "rb") as fp: diff --git a/test/test_video_reader.py b/test/test_video_reader.py index 243aa12fc12..10995424982 100644 --- a/test/test_video_reader.py +++ b/test/test_video_reader.py @@ -11,7 +11,7 @@ from numpy.random import randint from pytest import approx from torchvision import set_video_backend -from torchvision.io import _HAS_VIDEO_OPT +from torchvision.io import _HAS_CPU_VIDEO_DECODER try: @@ -263,7 +263,7 @@ def _get_video_tensor(video_dir, video_file): @pytest.mark.skipif(av is None, reason="PyAV unavailable") -@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg") +@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg") class TestVideoReader: def check_separate_decoding_result(self, tv_result, config): """check the decoding results from TorchVision decoder""" diff --git a/test/test_videoapi.py b/test/test_videoapi.py index dc878ca9f8c..aabcf6407f7 100644 --- a/test/test_videoapi.py +++ b/test/test_videoapi.py @@ -7,7 +7,7 @@ import torchvision from pytest import approx from torchvision.datasets.utils import download_url -from torchvision.io import _HAS_VIDEO_OPT, VideoReader +from torchvision.io import _HAS_CPU_VIDEO_DECODER, VideoReader # WARNING: these tests have been skipped forever on the CI because the video ops @@ -62,7 +62,7 @@ def fate(name, path="."): } -@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg") +@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg") class TestVideoApi: @pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.parametrize("test_video", test_videos.keys()) diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 857625a783c..5d06156c25f 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -72,7 +72,7 @@ def set_video_backend(backend): global _video_backend if backend not in ["pyav", "video_reader", "cuda"]: raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend) - if backend == "video_reader" and not io._HAS_VIDEO_OPT: + if backend == "video_reader" and not io._HAS_CPU_VIDEO_DECODER: # TODO: better messages message = "video_reader video backend is not available. Please compile torchvision from source and try again" raise RuntimeError(message) diff --git a/torchvision/csrc/io/image/cpu/decode_avif.cpp b/torchvision/csrc/io/image/cpu/decode_avif.cpp new file mode 100644 index 00000000000..5752f04a448 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_avif.cpp @@ -0,0 +1,115 @@ +#include "decode_avif.h" + +#if AVIF_FOUND +#include "avif/avif.h" +#endif // AVIF_FOUND + +namespace vision { +namespace image { + +#if !AVIF_FOUND +torch::Tensor decode_avif( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + TORCH_CHECK( + false, "decode_avif: torchvision not compiled with libavif support"); +} +#else + +// This normally comes from avif_cxx.h, but it's not always present when +// installing libavif. So we just copy/paste it here. +struct UniquePtrDeleter { + void operator()(avifDecoder* decoder) const { + avifDecoderDestroy(decoder); + } +}; +using DecoderPtr = std::unique_ptr; + +torch::Tensor decode_avif( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + // This is based on + // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c + // Refer there for more detail about what each function does, and which + // structure/data is available after which call. + + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1, + "Input tensor must be 1-dimensional, got ", + encoded_data.dim(), + " dims."); + + DecoderPtr decoder(avifDecoderCreate()); + TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder."); + + auto result = AVIF_RESULT_UNKNOWN_ERROR; + result = avifDecoderSetIOMemory( + decoder.get(), encoded_data.data_ptr(), encoded_data.numel()); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifDecoderSetIOMemory failed:", + avifResultToString(result)); + + result = avifDecoderParse(decoder.get()); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifDecoderParse failed: ", + avifResultToString(result)); + TORCH_CHECK( + decoder->imageCount == 1, "Avif file contains more than one image"); + + result = avifDecoderNextImage(decoder.get()); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifDecoderNextImage failed:", + avifResultToString(result)); + + avifRGBImage rgb; + memset(&rgb, 0, sizeof(rgb)); + avifRGBImageSetDefaults(&rgb, decoder->image); + + // images encoded as 10 or 12 bits will be decoded as uint16. The rest are + // decoded as uint8. + auto use_uint8 = (decoder->image->depth <= 8); + rgb.depth = use_uint8 ? 8 : 16; + + if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && + mode != IMAGE_READ_MODE_RGB_ALPHA) { + // Other modes aren't supported, but we don't error or even warn because we + // have generic entry points like decode_image which may support all modes, + // it just depends on the underlying decoder. + mode = IMAGE_READ_MODE_UNCHANGED; + } + + // If return_rgb is false it means we return rgba - nothing else. + auto return_rgb = + (mode == IMAGE_READ_MODE_RGB || + (mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent)); + + auto num_channels = return_rgb ? 3 : 4; + rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA; + rgb.ignoreAlpha = return_rgb ? AVIF_TRUE : AVIF_FALSE; + + auto out = torch::empty( + {rgb.height, rgb.width, num_channels}, + use_uint8 ? torch::kUInt8 : at::kUInt16); + rgb.pixels = (uint8_t*)out.data_ptr(); + rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb); + + result = avifImageYUVToRGB(decoder->image, &rgb); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifImageYUVToRGB failed: ", + avifResultToString(result)); + + return out.permute({2, 0, 1}); // return CHW, channels-last +} +#endif // AVIF_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_avif.h b/torchvision/csrc/io/image/cpu/decode_avif.h new file mode 100644 index 00000000000..0510c2104e5 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_avif.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_avif( + const torch::Tensor& encoded_data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index ed527a44b31..e5a421b7287 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -1,5 +1,6 @@ #include "decode_image.h" +#include "decode_avif.h" #include "decode_gif.h" #include "decode_jpeg.h" #include "decode_png.h" @@ -48,13 +49,25 @@ torch::Tensor decode_image( return decode_gif(data); } + // We assume the signature of an avif file is + // 0000 0020 6674 7970 6176 6966 + // xxxx xxxx f t y p a v i f + // We only check for the "ftyp avif" part. + // This is probably not perfect, but hopefully this should cover most files. + const uint8_t avif_signature[8] = { + 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif" + TORCH_CHECK(data.numel() >= 12, err_msg); + if ((memcmp(avif_signature, datap + 4, 8) == 0)) { + return decode_avif(data, mode); + } + const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" const uint8_t webp_signature_end[7] = { 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" TORCH_CHECK(data.numel() >= 15, err_msg); if ((memcmp(webp_signature_begin, datap, 4) == 0) && (memcmp(webp_signature_end, datap + 8, 7) == 0)) { - return decode_webp(data); + return decode_webp(data, mode); } TORCH_CHECK(false, err_msg); diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp index 844ce61a3e3..bf115c23c41 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.cpp +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -8,13 +8,17 @@ namespace vision { namespace image { #if !WEBP_FOUND -torch::Tensor decode_webp(const torch::Tensor& data) { +torch::Tensor decode_webp( + const torch::Tensor& encoded_data, + ImageReadMode mode) { TORCH_CHECK( false, "decode_webp: torchvision not compiled with libwebp support"); } #else -torch::Tensor decode_webp(const torch::Tensor& encoded_data) { +torch::Tensor decode_webp( + const torch::Tensor& encoded_data, + ImageReadMode mode) { TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); TORCH_CHECK( encoded_data.dtype() == torch::kU8, @@ -26,13 +30,43 @@ torch::Tensor decode_webp(const torch::Tensor& encoded_data) { encoded_data.dim(), " dims."); + auto encoded_data_p = encoded_data.data_ptr(); + auto encoded_data_size = encoded_data.numel(); + + WebPBitstreamFeatures features; + auto res = WebPGetFeatures(encoded_data_p, encoded_data_size, &features); + TORCH_CHECK( + res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res); + TORCH_CHECK( + !features.has_animation, "Animated webp files are not supported."); + + if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && + mode != IMAGE_READ_MODE_RGB_ALPHA) { + // Other modes aren't supported, but we don't error or even warn because we + // have generic entry points like decode_image which may support all modes, + // it just depends on the underlying decoder. + mode = IMAGE_READ_MODE_UNCHANGED; + } + + // If return_rgb is false it means we return rgba - nothing else. + auto return_rgb = + (mode == IMAGE_READ_MODE_RGB || + (mode == IMAGE_READ_MODE_UNCHANGED && !features.has_alpha)); + + auto decoding_func = return_rgb ? WebPDecodeRGB : WebPDecodeRGBA; + auto num_channels = return_rgb ? 3 : 4; + int width = 0; int height = 0; - auto decoded_data = WebPDecodeRGB( - encoded_data.data_ptr(), encoded_data.numel(), &width, &height); - TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed."); - auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8); - return out.permute({2, 0, 1}); // return CHW, channels-last + + auto decoded_data = + decoding_func(encoded_data_p, encoded_data_size, &width, &height); + TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed."); + + auto out = torch::from_blob( + decoded_data, {height, width, num_channels}, torch::kUInt8); + + return out.permute({2, 0, 1}); } #endif // WEBP_FOUND diff --git a/torchvision/csrc/io/image/cpu/decode_webp.h b/torchvision/csrc/io/image/cpu/decode_webp.h index 00a0c3362f7..5632ea56ff9 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.h +++ b/torchvision/csrc/io/image/cpu/decode_webp.h @@ -1,11 +1,14 @@ #pragma once #include +#include "../image_read_mode.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data); +C10_EXPORT torch::Tensor decode_webp( + const torch::Tensor& encoded_data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 8ca2f814996..a777d19d3bd 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -21,7 +21,10 @@ static auto registry = .op("image::encode_png", &encode_png) .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_jpeg) - .op("image::decode_webp", &decode_webp) + .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", + &decode_webp) + .op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor", + &decode_avif) .op("image::encode_jpeg", &encode_jpeg) .op("image::read_file", &read_file) .op("image::write_file", &write_file) diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 3f47fdec65c..91a5144fa1c 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -1,5 +1,6 @@ #pragma once +#include "cpu/decode_avif.h" #include "cpu/decode_gif.h" #include "cpu/decode_image.h" #include "cpu/decode_jpeg.h" diff --git a/torchvision/datasets/moving_mnist.py b/torchvision/datasets/moving_mnist.py index d02811762b8..48715de4e8d 100644 --- a/torchvision/datasets/moving_mnist.py +++ b/torchvision/datasets/moving_mnist.py @@ -66,7 +66,7 @@ def __init__( def __getitem__(self, idx: int) -> torch.Tensor: """ Args: - index (int): Index + idx (int): Index Returns: torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames. """ diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 780b6ab333e..08a0d6d62b7 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -10,6 +10,7 @@ _HAS_GPU_VIDEO_DECODER = False from ._video_opt import ( + _HAS_CPU_VIDEO_DECODER, _HAS_VIDEO_OPT, _probe_video_from_file, _probe_video_from_memory, @@ -49,6 +50,7 @@ "_read_video_from_memory", "_read_video_timestamps_from_memory", "_probe_video_from_memory", + "_HAS_CPU_VIDEO_DECODER", "_HAS_VIDEO_OPT", "_HAS_GPU_VIDEO_DECODER", "_read_video_clip_from_memory", @@ -59,6 +61,8 @@ "decode_image", "decode_jpeg", "decode_png", + "decode_webp", + "decode_gif", "encode_jpeg", "encode_png", "read_file", diff --git a/torchvision/io/_load_gpu_decoder.py b/torchvision/io/_load_gpu_decoder.py index f7869f0a9d1..cfd40c545d8 100644 --- a/torchvision/io/_load_gpu_decoder.py +++ b/torchvision/io/_load_gpu_decoder.py @@ -2,7 +2,7 @@ try: - _load_library("Decoder") + _load_library("gpu_decoder") _HAS_GPU_VIDEO_DECODER = True except (ImportError, OSError): _HAS_GPU_VIDEO_DECODER = False diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 2bd7d11929e..69af045e773 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -10,10 +10,11 @@ try: _load_library("video_reader") - _HAS_VIDEO_OPT = True + _HAS_CPU_VIDEO_DECODER = True except (ImportError, OSError): - _HAS_VIDEO_OPT = False + _HAS_CPU_VIDEO_DECODER = False +_HAS_VIDEO_OPT = _HAS_CPU_VIDEO_DECODER # For BC default_timebase = Fraction(0, 1) diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 3414e280e68..e169c0a4f7a 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -28,6 +28,11 @@ class ImageReadMode(Enum): ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency, ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for RGB with transparency. + + .. note:: + + Some decoders won't support all possible values, e.g. a decoder may only + support "RGB" and "RGBA" mode. """ UNCHANGED = 0 @@ -365,20 +370,52 @@ def decode_gif(input: torch.Tensor) -> torch.Tensor: def decode_webp( input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, ) -> torch.Tensor: """ - Decode a WEBP image into a 3 dimensional RGB Tensor. + Decode a WEBP image into a 3 dimensional RGB[A] Tensor. - The values of the output tensor are uint8 between 0 and 255. If the input - image is RGBA, the transparency is ignored. + The values of the output tensor are uint8 between 0 and 255. Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the WEBP image. + mode (ImageReadMode): The read mode used for optionally + converting the image color space. Default: ``ImageReadMode.UNCHANGED``. + Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + + Returns: + Decoded image (Tensor[image_channels, image_height, image_width]) + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(decode_webp) + return torch.ops.image.decode_webp(input, mode.value) + + +def _decode_avif( + input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, +) -> torch.Tensor: + """ + Decode an AVIF image into a 3 dimensional RGB[A] Tensor. + + The values of the output tensor are in uint8 in [0, 255] for most images. If + the image has a bit-depth of more than 8, then the output tensor is uint16 + in [0, 65535]. Since uint16 support is limited in pytorch, we recommend + calling :func:`torchvision.transforms.v2.functional.to_dtype()` with + ``scale=True`` after this function to convert the decoded image into a uint8 + or float tensor. + + Args: + input (Tensor[1]): a one dimensional contiguous uint8 tensor containing + the raw bytes of the AVIF image. + mode (ImageReadMode): The read mode used for optionally + converting the image color space. Default: ``ImageReadMode.UNCHANGED``. + Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. Returns: Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_webp) - return torch.ops.image.decode_webp(input) + return torch.ops.image.decode_avif(input, mode.value) diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py index c00723a4534..505909fd984 100644 --- a/torchvision/io/video_reader.py +++ b/torchvision/io/video_reader.py @@ -7,9 +7,9 @@ from ..utils import _log_api_usage_once -from ._video_opt import _HAS_VIDEO_OPT +from ._video_opt import _HAS_CPU_VIDEO_DECODER -if _HAS_VIDEO_OPT: +if _HAS_CPU_VIDEO_DECODER: def _has_video_opt() -> bool: return True diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 34d1e101dbd..eb75f58cb7a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -66,7 +66,7 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor: - """See :class:`~torchvision.transforms.v2.GrayscaleToRgb` for details.""" + """See :class:`~torchvision.transforms.v2.RGB` for details.""" if torch.jit.is_scripting(): return grayscale_to_rgb_image(inpt)