Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 16 bits png images #4657

Merged
merged 10 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test/assets/fakedata/logos/rgb_pytorch16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/rgbalpha_pytorch16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ def test_decode_png(img_path, pil_mode, mode):
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
img_lpng, img_pil = img_lpng[0], img_pil[0]

if "16" in img_path:
# PIL converts 16 bits pngs in uint8
assert img_lpng.dtype == torch.int32
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)

torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)


Expand Down
56 changes: 46 additions & 10 deletions torchvision/csrc/io/image/cpu/decode_png.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
}
#else

bool is_little_endian() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uint32_t x = 1;
return *(uint8_t*)&x;
}

torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
Expand Down Expand Up @@ -72,9 +77,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}

if (bit_depth > 8) {
if (bit_depth > 16) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "At most 8-bit PNG images are supported currently.")
TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
}

int channels = png_get_channels(png_ptr, info_ptr);
Expand Down Expand Up @@ -168,15 +173,46 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
png_read_update_info(png_ptr, info_ptr);
}

auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += width * channels;
auto num_pixels_per_row = width * channels;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), channels},
bit_depth <= 8 ? torch::kU8 : torch::kI32);

if (bit_depth <= 8) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this if block is unchanged and corresponds to the original code. I just renamed ptr into t_ptr, because the other block uses too many pointers for ptr to be explicit enough

auto t_ptr = tensor.accessor<uint8_t, 3>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<uint8_t, 3>().data();
}
} else {
// We're reading a 16bits png, but pytorch doesn't support uint16.
// So we read each row in a 16bits tmp_buffer which we then cast into
// a int32 tensor instead.
if (is_little_endian()) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmassa I eventually realized that this was a much cleaner and simpler way to handle the endianness. The rest takes care of itself when we cast the uint16 value into a int32_t a few lines below

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!

png_set_swap(png_ptr);
}
int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();

// We create a tensor instead of malloc-ing for automatic memory management
auto tmp_buffer_tensor = torch::empty(
{int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
uint16_t* tmp_buffer =
(uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit because it was already like this before: you can just do tmp_buffer_tensor.data_ptr<uint8_t, 1>()


for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
// Now we copy the uint16 values into the int32 tensor.
for (size_t j = 0; j < num_pixels_per_row; ++j) {
t_ptr[j] = (int32_t)tmp_buffer[j];
}
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<int32_t, 3>().data();
}
ptr = tensor.accessor<uint8_t, 3>().data();
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
Expand Down
13 changes: 10 additions & 3 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].

.. warning::
Should pytorch ever support the uint16 dtype natively, the dtype of the
output for 16-bits pngs will be updated from int32 to uint16.

Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
Expand Down Expand Up @@ -188,7 +193,8 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
operation to decode the image into a 3 dimensional RGB Tensor.

Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].

Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
Expand All @@ -209,7 +215,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].

Args:
path (str): path of the JPEG or PNG image.
Expand Down