Skip to content

Commit

Permalink
[fbsync] Add support for 16 bits png images (#4657)
Browse files Browse the repository at this point in the history
Summary:
* WIP

* cleaner code

* Add tests

* Add docs

* Assert dtype

* put back check

* Address comments

Reviewed By: NicolasHug

Differential Revision: D31916334

fbshipit-source-id: 8877266f6e533e8c45c5f202e535944a9a939376

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
2 people authored and facebook-github-bot committed Oct 26, 2021
1 parent be3ef03 commit 83d6f0e
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 13 deletions.
Binary file added test/assets/fakedata/logos/rgb_pytorch16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/rgbalpha_pytorch16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ def test_decode_png(img_path, pil_mode, mode):
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
img_lpng, img_pil = img_lpng[0], img_pil[0]

if "16" in img_path:
# PIL converts 16-bit PNGs to uint8
assert img_lpng.dtype == torch.int32
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)

torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)


Expand Down
56 changes: 46 additions & 10 deletions torchvision/csrc/io/image/cpu/decode_png.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
}
#else

// Reports whether the host stores multi-byte integers with the least
// significant byte at the lowest address.
//
// Returns true on little-endian hosts and false otherwise.
bool is_little_endian() {
  const uint32_t probe = 1;
  // Inspect the byte at the lowest address of the 32-bit value 1: it holds
  // the 1 on a little-endian host and 0 on a big-endian one.
  const auto* lowest_byte = reinterpret_cast<const uint8_t*>(&probe);
  return lowest_byte[0] != 0;
}

torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
Expand Down Expand Up @@ -72,9 +77,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}

if (bit_depth > 8) {
if (bit_depth > 16) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "At most 8-bit PNG images are supported currently.")
TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
}

int channels = png_get_channels(png_ptr, info_ptr);
Expand Down Expand Up @@ -168,15 +173,46 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
png_read_update_info(png_ptr, info_ptr);
}

auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += width * channels;
auto num_pixels_per_row = width * channels;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), channels},
bit_depth <= 8 ? torch::kU8 : torch::kI32);

if (bit_depth <= 8) {
auto t_ptr = tensor.accessor<uint8_t, 3>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<uint8_t, 3>().data();
}
} else {
// We're reading a 16-bit PNG, but PyTorch doesn't support uint16.
// So we read each row into a 16-bit tmp_buffer and then widen each
// value into an int32 tensor instead.
if (is_little_endian()) {
png_set_swap(png_ptr);
}
int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();

// We create a tensor instead of malloc-ing for automatic memory management
auto tmp_buffer_tensor = torch::empty(
{int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
uint16_t* tmp_buffer =
(uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();

for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
// Now we copy the uint16 values into the int32 tensor.
for (size_t j = 0; j < num_pixels_per_row; ++j) {
t_ptr[j] = (int32_t)tmp_buffer[j];
}
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<int32_t, 3>().data();
}
ptr = tensor.accessor<uint8_t, 3>().data();
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
Expand Down
13 changes: 10 additions & 3 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
.. warning::
Should pytorch ever support the uint16 dtype natively, the dtype of the
output for 16-bits pngs will be updated from int32 to uint16.
Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
Expand Down Expand Up @@ -188,7 +193,8 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
Expand All @@ -209,7 +215,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
Args:
path (str): path of the JPEG or PNG image.
Expand Down

0 comments on commit 83d6f0e

Please sign in to comment.