Add support for 16 bits png images #4657
Changes from 9 commits: 29e1b68, 3a51b57, ec47775, e1568d3, 7b14f34, 9cd72a6, 2d298d9, 67ea116, 6df0105, 3fa9c2a
@@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
 }
 #else

+bool is_little_endian() {
+  uint32_t x = 1;
+  return *(uint8_t*)&x;
+}
+
 torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
   // Check that the input tensor dtype is uint8
   TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
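For reference, the endianness check added above can be exercised on its own; the following is a minimal standalone sketch (the main() driver is ours, purely for illustration):

#include <cstdint>
#include <cstdio>

// Same trick as the is_little_endian() added above: store 1 in a 32-bit
// integer and look at its first byte in memory. On a little-endian machine
// that byte holds the 1; on a big-endian machine it holds 0.
bool is_little_endian() {
  uint32_t x = 1;
  return *(uint8_t*)&x;
}

int main() {
  std::printf("little endian: %d\n", is_little_endian() ? 1 : 0);
  return 0;
}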
@@ -72,9 +77,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
     TORCH_CHECK(retval == 1, "Could read image metadata from content.")
   }

-  if (bit_depth > 8) {
+  if (bit_depth > 16) {
     png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
-    TORCH_CHECK(false, "At most 8-bit PNG images are supported currently.")
+    TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
   }

   int channels = png_get_channels(png_ptr, info_ptr);
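For context, the bit_depth tested here and the channels value fetched just below both come from libpng's info accessors; a minimal sketch of how they are typically queried once the header has been read (png_ptr and info_ptr stand for the read/info structs set up earlier in decode_png):

#include <cstdio>
#include <png.h>

// Sketch only: print the header fields that the bit-depth check above relies on.
void print_png_metadata(png_structp png_ptr, png_infop info_ptr) {
  int bit_depth = png_get_bit_depth(png_ptr, info_ptr); // 1, 2, 4, 8 or 16
  int channels = png_get_channels(png_ptr, info_ptr); // 1 (gray) up to 4 (RGBA)
  std::printf("bit depth: %d, channels: %d\n", bit_depth, channels);
}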
@@ -168,15 +173,46 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
     png_read_update_info(png_ptr, info_ptr);
   }

-  auto tensor =
-      torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
-  auto ptr = tensor.accessor<uint8_t, 3>().data();
-  for (int pass = 0; pass < number_of_passes; pass++) {
-    for (png_uint_32 i = 0; i < height; ++i) {
-      png_read_row(png_ptr, ptr, nullptr);
-      ptr += width * channels;
+  auto num_pixels_per_row = width * channels;
+  auto tensor = torch::empty(
+      {int64_t(height), int64_t(width), channels},
+      bit_depth <= 8 ? torch::kU8 : torch::kI32);
+
+  if (bit_depth <= 8) {
+    auto t_ptr = tensor.accessor<uint8_t, 3>().data();
+    for (int pass = 0; pass < number_of_passes; pass++) {
+      for (png_uint_32 i = 0; i < height; ++i) {
+        png_read_row(png_ptr, t_ptr, nullptr);
+        t_ptr += num_pixels_per_row;
+      }
+      t_ptr = tensor.accessor<uint8_t, 3>().data();
+    }
+  } else {
+    // We're reading a 16bits png, but pytorch doesn't support uint16.
+    // So we read each row in a 16bits tmp_buffer which we then cast into
+    // a int32 tensor instead.
+    if (is_little_endian()) {
Review comment: @fmassa I eventually realized that this was a much cleaner and simpler way to handle the endianness. The rest takes care of itself when we cast the uint16 value into an int32_t a few lines below.
Reply: Awesome!
+      png_set_swap(png_ptr);
+    }
+    int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();
+
+    // We create a tensor instead of malloc-ing for automatic memory management
+    auto tmp_buffer_tensor = torch::empty(
+        {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
+    uint16_t* tmp_buffer =
+        (uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
Review comment: nit because it was already like this before: you can just do …
+    for (int pass = 0; pass < number_of_passes; pass++) {
+      for (png_uint_32 i = 0; i < height; ++i) {
+        png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
+        // Now we copy the uint16 values into the int32 tensor.
+        for (size_t j = 0; j < num_pixels_per_row; ++j) {
+          t_ptr[j] = (int32_t)tmp_buffer[j];
+        }
+        t_ptr += num_pixels_per_row;
+      }
+      t_ptr = tensor.accessor<int32_t, 3>().data();
+    }
-    ptr = tensor.accessor<uint8_t, 3>().data();
   }
   png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
   return tensor.permute({2, 0, 1});
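A hedged usage note on the result: with this change, decode_png yields an int32 CHW tensor for 16-bit inputs (samples in [0, 65535]) since PyTorch has no uint16 dtype; a caller wanting floats in [0, 1] could rescale along these lines (the helper name is ours, not part of the torchvision API):

#include <torch/torch.h>

// Sketch: rescale the int32 output of decode_png for a 16-bit image into a
// float tensor in [0, 1]. Assumes samples span the full 16-bit range.
torch::Tensor to_unit_float(const torch::Tensor& img) {
  return img.to(torch::kFloat32).div(65535.0);
}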
Review comment: stolen from https://github.com/pytorch/pytorch/blob/4c4525fa5cffb924d0e9b844449e6bd0a0df4aff/torch/csrc/utils/byte_order.cpp#L118
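To make the endianness handling discussed above concrete: PNG stores 16-bit samples big-endian, so on little-endian machines png_set_swap tells libpng to swap each sample's two bytes before the row buffer is reinterpreted as uint16_t. A minimal sketch of the equivalent swap (illustrative only, not how libpng implements it):

#include <cstdint>

// Swap the two bytes of a 16-bit sample, e.g. 0x01FF <-> 0xFF01.
inline uint16_t swap_u16(uint16_t v) {
  return static_cast<uint16_t>((v << 8) | (v >> 8));
}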