Skip to content

Commit

Permalink
Merge branch 'main' of github.com:pytorch/vision into avif_release
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Aug 28, 2024
2 parents 61b6d20 + a59c939 commit 23057aa
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 31 deletions.
88 changes: 87 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def test_decode_gif_webp_errors(decode_fun):
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
elif decode_fun is decode_webp:
expected_match = "WebPDecodeRGB failed."
expected_match = "WebPGetFeatures failed."
else:
expected_match = "avifDecoderParse failed: BMFF parsing failed"
with pytest.raises(RuntimeError, match=expected_match):
Expand All @@ -893,6 +893,31 @@ def test_decode_webp(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


# This test is skipped because it requires webp images that we're not including
# within the repo. The test images were downloaded from the different pages of
# https://developers.google.com/speed/webp/gallery
# Note that converting an RGBA image to RGB leads to bad results because the
# transparent pixels aren't necessarily set to "black" or "white"; they can hold
# arbitrary values. This is consistent with PIL results.
@pytest.mark.skip(reason="Need to download test images first")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
    "mode, pil_mode",
    (
        (ImageReadMode.RGB, "RGB"),
        (ImageReadMode.RGB_ALPHA, "RGBA"),
        (ImageReadMode.UNCHANGED, None),
    ),
)
@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp"))
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
    """Compare our WEBP decoding against Pillow for each supported read mode.

    Every ``*.webp`` sample is decoded with ``decode_fun`` (optionally
    TorchScript-compiled) and must match, exactly, the tensor obtained by
    opening the same file with Pillow and converting it to ``pil_mode``.
    """
    raw_bytes = read_file(filename)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun

    decoded = decoder(raw_bytes, mode=mode)
    # The decoder returns a permuted HWC buffer, i.e. a channels-last CHW view.
    assert decoded[None].is_contiguous(memory_format=torch.channels_last)

    # pil_mode=None lets Pillow keep the image's native mode (UNCHANGED).
    expected = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
    assert_equal(decoded, expected)


@pytest.mark.parametrize("decode_fun", (decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
Expand All @@ -904,5 +929,66 @@ def test_decode_avif(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


# Run on avif files from https://github.com/AOMediaCodec/libavif/tree/main/tests/data
@pytest.mark.skip(reason="Need to download test images first")
# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("decode_fun", (decode_avif,))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
    "mode, pil_mode",
    (
        (ImageReadMode.RGB, "RGB"),
        (ImageReadMode.RGB_ALPHA, "RGBA"),
        (ImageReadMode.UNCHANGED, None),
    ),
)
@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"))
def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename):
    """Compare our AVIF decoding against Pillow (with the pillow_avif plugin).

    Files that our decoder legitimately rejects (unparsable containers,
    multi-image files) are skipped. uint16 outputs are rescaled to uint8 before
    the comparison, and results are compared with an absolute tolerance of 3.
    """
    if "reversed_dimg_order" in str(filename):
        # Pillow properly decodes this one, but we don't (order of parts of the
        # image is wrong). This is due to a bug that was recently fixed in
        # libavif. Hopefully this test will end up passing soon with a new
        # libavif version https://github.com/AOMediaCodec/libavif/issues/2311
        pytest.xfail(reason="Known libavif bug: https://github.com/AOMediaCodec/libavif/issues/2311")

    # Imported only for its side effect: it is needed so that Image.open()
    # below can read .avif files.
    import pillow_avif  # noqa

    encoded_bytes = read_file(filename)
    if scripted:
        decode_fun = torch.jit.script(decode_fun)

    expected_failures = (
        "BMFF parsing failed",
        "avifDecoderParse failed: ",
        "file contains more than one image",
    )
    try:
        img = decode_fun(encoded_bytes, mode=mode)
    except RuntimeError as e:
        if any(s in str(e) for s in expected_failures):
            pytest.skip(reason="Expected failure, that's OK")
        # Anything else is a genuine failure: re-raise with the original traceback.
        raise

    assert img[None].is_contiguous(memory_format=torch.channels_last)
    if mode == ImageReadMode.RGB:
        assert img.shape[0] == 3
    if mode == ImageReadMode.RGB_ALPHA:
        assert img.shape[0] == 4
    if img.dtype == torch.uint16:
        # Images with a bit depth above 8 are decoded as uint16; bring them back
        # to the uint8 range so they can be compared with Pillow's output.
        img = F.to_dtype(img, dtype=torch.uint8, scale=True)

    from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
    if mode != ImageReadMode.RGB:
        # We don't compare against PIL for RGB because results look pretty
        # different on RGBA images (other images are fine). The result on
        # torchvision basically just plainly ignores the alpha channel, resulting
        # in transparent pixels looking dark. PIL seems to be using a sort of
        # k-nn thing, looking at the output. To inspect visually, save
        # make_grid([img, from_pil]) and take a look at the resulting images.
        torch.testing.assert_close(img, from_pil, rtol=0, atol=3)


# Allow running this test module directly (``python test_image.py``) by handing
# the file over to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
43 changes: 33 additions & 10 deletions torchvision/csrc/io/image/cpu/decode_avif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ namespace vision {
namespace image {

#if !AVIF_FOUND
torch::Tensor decode_avif(const torch::Tensor& data) {
torch::Tensor decode_avif(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(
false, "decode_avif: torchvision not compiled with libavif support");
}
Expand All @@ -23,7 +25,9 @@ struct UniquePtrDeleter {
};
using DecoderPtr = std::unique_ptr<avifDecoder, UniquePtrDeleter>;

torch::Tensor decode_avif(const torch::Tensor& encoded_data) {
torch::Tensor decode_avif(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
// This is based on
// https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c
// Refer there for more detail about what each function does, and which
Expand Down Expand Up @@ -58,24 +62,43 @@ torch::Tensor decode_avif(const torch::Tensor& encoded_data) {
avifResultToString(result));
TORCH_CHECK(
decoder->imageCount == 1, "Avif file contains more than one image");
TORCH_CHECK(
decoder->image->depth <= 8,
"avif images with bitdepth > 8 are not supported");

result = avifDecoderNextImage(decoder.get());
TORCH_CHECK(
result == AVIF_RESULT_OK,
"avifDecoderNextImage failed:",
avifResultToString(result));

auto out = torch::empty(
{decoder->image->height, decoder->image->width, 3}, torch::kUInt8);

avifRGBImage rgb;
memset(&rgb, 0, sizeof(rgb));
avifRGBImageSetDefaults(&rgb, decoder->image);
rgb.format = AVIF_RGB_FORMAT_RGB;
rgb.pixels = out.data_ptr<uint8_t>();

// images encoded as 10 or 12 bits will be decoded as uint16. The rest are
// decoded as uint8.
auto use_uint8 = (decoder->image->depth <= 8);
rgb.depth = use_uint8 ? 8 : 16;

if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
mode != IMAGE_READ_MODE_RGB_ALPHA) {
// Other modes aren't supported, but we don't error or even warn because we
// have generic entry points like decode_image which may support all modes,
// it just depends on the underlying decoder.
mode = IMAGE_READ_MODE_UNCHANGED;
}

// If return_rgb is false it means we return rgba - nothing else.
auto return_rgb =
(mode == IMAGE_READ_MODE_RGB ||
(mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent));

auto num_channels = return_rgb ? 3 : 4;
rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA;
rgb.ignoreAlpha = return_rgb ? AVIF_TRUE : AVIF_FALSE;

auto out = torch::empty(
{rgb.height, rgb.width, num_channels},
use_uint8 ? torch::kUInt8 : at::kUInt16);
rgb.pixels = (uint8_t*)out.data_ptr();
rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb);

result = avifImageYUVToRGB(decoder->image, &rgb);
Expand Down
5 changes: 4 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_avif.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data);
C10_EXPORT torch::Tensor decode_avif(
const torch::Tensor& encoded_data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
4 changes: 2 additions & 2 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ torch::Tensor decode_image(
0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif"
TORCH_CHECK(data.numel() >= 12, err_msg);
if ((memcmp(avif_signature, datap + 4, 8) == 0)) {
return decode_avif(data);
return decode_avif(data, mode);
}

const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
Expand All @@ -67,7 +67,7 @@ torch::Tensor decode_image(
TORCH_CHECK(data.numel() >= 15, err_msg);
if ((memcmp(webp_signature_begin, datap, 4) == 0) &&
(memcmp(webp_signature_end, datap + 8, 7) == 0)) {
return decode_webp(data);
return decode_webp(data, mode);
}

TORCH_CHECK(false, err_msg);
Expand Down
48 changes: 41 additions & 7 deletions torchvision/csrc/io/image/cpu/decode_webp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ namespace vision {
namespace image {

#if !WEBP_FOUND
torch::Tensor decode_webp(const torch::Tensor& data) {
torch::Tensor decode_webp(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(
false, "decode_webp: torchvision not compiled with libwebp support");
}
#else

torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
torch::Tensor decode_webp(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
Expand All @@ -26,13 +30,43 @@ torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
encoded_data.dim(),
" dims.");

auto encoded_data_p = encoded_data.data_ptr<uint8_t>();
auto encoded_data_size = encoded_data.numel();

WebPBitstreamFeatures features;
auto res = WebPGetFeatures(encoded_data_p, encoded_data_size, &features);
TORCH_CHECK(
res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res);
TORCH_CHECK(
!features.has_animation, "Animated webp files are not supported.");

if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
mode != IMAGE_READ_MODE_RGB_ALPHA) {
// Other modes aren't supported, but we don't error or even warn because we
// have generic entry points like decode_image which may support all modes,
// it just depends on the underlying decoder.
mode = IMAGE_READ_MODE_UNCHANGED;
}

// If return_rgb is false it means we return rgba - nothing else.
auto return_rgb =
(mode == IMAGE_READ_MODE_RGB ||
(mode == IMAGE_READ_MODE_UNCHANGED && !features.has_alpha));

auto decoding_func = return_rgb ? WebPDecodeRGB : WebPDecodeRGBA;
auto num_channels = return_rgb ? 3 : 4;

int width = 0;
int height = 0;
auto decoded_data = WebPDecodeRGB(
encoded_data.data_ptr<uint8_t>(), encoded_data.numel(), &width, &height);
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed.");
auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8);
return out.permute({2, 0, 1}); // return CHW, channels-last

auto decoded_data =
decoding_func(encoded_data_p, encoded_data_size, &width, &height);
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed.");

auto out = torch::from_blob(
decoded_data, {height, width, num_channels}, torch::kUInt8);

return out.permute({2, 0, 1});
}
#endif // WEBP_FOUND

Expand Down
5 changes: 4 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_webp.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data);
C10_EXPORT torch::Tensor decode_webp(
const torch::Tensor& encoded_data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
6 changes: 4 additions & 2 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ static auto registry =
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_jpeg)
.op("image::decode_webp", &decode_webp)
.op("image::decode_avif", &decode_avif)
.op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
&decode_webp)
.op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor",
&decode_avif)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/moving_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
def __getitem__(self, idx: int) -> torch.Tensor:
"""
Args:
index (int): Index
idx (int): Index
Returns:
torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
"""
Expand Down
39 changes: 34 additions & 5 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class ImageReadMode(Enum):
``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
RGB with transparency.
.. note::
Some decoders won't support all possible values, e.g. a decoder may only
support "RGB" and "RGBA" mode.
"""

UNCHANGED = 0
Expand Down Expand Up @@ -365,28 +370,52 @@ def decode_gif(input: torch.Tensor) -> torch.Tensor:

def decode_webp(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode a WEBP image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the WEBP image.
        mode (ImageReadMode): The read mode used for optionally
            converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
            Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
            Any other value is treated as ``ImageReadMode.UNCHANGED`` by the
            underlying decoder.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_webp)
    # The registered op takes the read mode as a plain int, hence ``mode.value``.
    return torch.ops.image.decode_webp(input, mode.value)


def decode_avif(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode an AVIF image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the AVIF image.
        mode (ImageReadMode): The read mode used for optionally
            converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
            Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
            Any other value is treated as ``ImageReadMode.UNCHANGED`` by the
            underlying decoder.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Log usage under this function's own name (was mistakenly decode_webp).
        _log_api_usage_once(decode_avif)
    # The registered op takes the read mode as a plain int, hence ``mode.value``.
    return torch.ops.image.decode_avif(input, mode.value)
Loading

0 comments on commit 23057aa

Please sign in to comment.