Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parsing of JPEG headers #175

Merged
merged 2 commits into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 20 additions & 28 deletions dali/image/jpeg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,40 +54,32 @@ void PrintSubsampling(int sampling) {
}
#endif // DALI_USE_JPEG_TURBO

// Slightly modified from https://github.com/apache/incubator-mxnet/blob/master/plugin/opencv/cv_api.cc
// http://www.64lines.com/jpeg-width-height
// Gets the JPEG size from the array of data passed to the function, file reference: http://www.obrador.com/essentialjpeg/headerinfo.htm
// Based on https://github.com/scardine/image_size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code seems only very slightly modified from the previous version, so I would put both references here (image_size and mxnet)

bool get_jpeg_size(const uint8 *data, size_t data_size, int *height, int *width) {
// Check for valid JPEG image
unsigned int i = 0; // Keeps track of the position within the file
if (data[i] == 0xFF && data[i+1] == 0xD8 && data[i+2] == 0xFF && data[i+3] == 0xE0) {
if (data[i] == 0xFF && data[i+1] == 0xD8) {
i += 4;
// Check for valid JPEG header (null terminated JFIF)
if (data[i+2] == 'J' && data[i+3] == 'F' && data[i+4] == 'I'
&& data[i+5] == 'F' && data[i+6] == 0x00) {
// Retrieve the block length of the first block since
// the first block will not contain the size of file
uint16_t block_length = data[i] * 256 + data[i+1];
while (i < data_size) {
i+=block_length; // Increase the file index to get to the next block
if (i >= data_size) return false; // Check to protect against segmentation faults
if (data[i] != 0xFF) return false; // Check that we are truly at the start of another block
if (data[i+1] == 0xC0) {
// 0xFFC0 is the "Start of frame" marker which contains the file size
// The structure of the 0xFFC0 block is quite simple
// [0xFFC0][ushort length][uchar precision][ushort x][ushort y]
*height = data[i+5]*256 + data[i+6];
*width = data[i+7]*256 + data[i+8];
return true;
} else {
i+=2; // Skip the block marker
block_length = data[i] * 256 + data[i+1]; // Go to the next block
}
// Retrieve the block length of the first block since
// the first block will not contain the size of file
uint16_t block_length = data[i] * 256 + data[i+1];
while (i < data_size) {
i+=block_length; // Increase the file index to get to the next block
if (i >= data_size) return false; // Check to protect against segmentation faults
if (data[i] != 0xFF) return false; // Check that we are truly at the start of another block
if (data[i+1] >= 0xC0 && data[i+1] <= 0xC3) {
// 0xFFC0 is the "Start of frame" marker which contains the file size
// The structure of the 0xFFC0 block is quite simple
// [0xFFC0][ushort length][uchar precision][ushort x][ushort y]
*height = data[i+5]*256 + data[i+6];
*width = data[i+7]*256 + data[i+8];
return true;
} else {
i+=2; // Skip the block marker
block_length = data[i] * 256 + data[i+1]; // Go to the next block
}
return false; // If this point is reached then no size was found
} else {
return false; // Not a valid JFIF string
}
return false; // If this point is reached then no size was found
} else {
return false; // Not a valid SOI header
}
Expand Down
8 changes: 5 additions & 3 deletions dali/test/python/test_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ class CommonPipeline(Pipeline):
def __init__(self, batch_size, num_threads, device_id):
super(CommonPipeline, self).__init__(batch_size, num_threads, device_id)

self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB)
self.decode_gpu = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB)
self.decode_host = ops.HostDecoder(device = "cpu", output_type = types.RGB)

def base_define_graph(self, inputs, labels):
images = self.decode(inputs)
return (images, labels)
images_gpu = self.decode_gpu(inputs)
images_host = self.decode_host(inputs)
return (images_gpu, images_host, labels)

class MXNetReaderPipeline(CommonPipeline):
def __init__(self, batch_size, num_threads, device_id, num_gpus, data_paths):
Expand Down