Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-780] Fix exception handling bug #12051

Merged
merged 9 commits into from
Sep 27, 2018
229 changes: 122 additions & 107 deletions src/io/iter_image_det_recordio.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,11 @@ class ImageDetRecordIOParser {
std::unique_ptr<ImageDetLabelMap> label_map_;
/*! \brief temp space */
mshadow::TensorContainer<cpu, 3> img_;
/*! \brief OMPException obj to store and rethrow exceptions from omp blocks*/
dmlc::OMPException omp_exc_;
};

template<typename DType>
template <typename DType>
inline void ImageDetRecordIOParser<DType>::Init(
const std::vector<std::pair<std::string, std::string> >& kwargs) {
#if MXNET_USE_OPENCV
Expand Down Expand Up @@ -282,8 +284,9 @@ inline void ImageDetRecordIOParser<DType>::Init(
<< ", use " << threadget << " threads for decoding..";
}
source_.reset(dmlc::InputSplit::Create(
param_.path_imgrec.c_str(), param_.part_index,
param_.num_parts, "recordio"));
param_.path_imgrec.c_str(),
param_.part_index, param_.num_parts,
"recordio"));

// estimate padding width for labels
int max_label_width = 0;
Expand All @@ -295,38 +298,41 @@ inline void ImageDetRecordIOParser<DType>::Init(
while (source_->NextChunk(&chunk)) {
#pragma omp parallel num_threads(param_.preprocess_threads)
{
CHECK(omp_get_num_threads() == param_.preprocess_threads);
int max_width = 0;
int tid = omp_get_thread_num();
dmlc::RecordIOChunkReader reader(chunk, tid, param_.preprocess_threads);
ImageRecordIO rec;
dmlc::InputSplit::Blob blob;
while (reader.NextRecord(&blob)) {
rec.Load(blob.dptr, blob.size);
if (rec.label != nullptr) {
if (param_.label_width > 0) {
CHECK_EQ(param_.label_width, rec.num_label)
<< "rec file provide " << rec.num_label << "-dimensional label "
"but label_width is set to " << param_.label_width;
omp_exc_.Run([&] {
CHECK(omp_get_num_threads() == param_.preprocess_threads);
int max_width = 0;
int tid = omp_get_thread_num();
dmlc::RecordIOChunkReader reader(chunk, tid,
param_.preprocess_threads);
ImageRecordIO rec;
dmlc::InputSplit::Blob blob;
while (reader.NextRecord(&blob)) {
rec.Load(blob.dptr, blob.size);
if (rec.label != nullptr) {
if (param_.label_width > 0) {
CHECK_EQ(param_.label_width, rec.num_label)
<< "rec file provide " << rec.num_label << "-dimensional label "
"but label_width is set to " << param_.label_width;
}
// update max value
max_width = std::max(max_width, rec.num_label);
} else {
LOG(FATAL) << "Not enough label packed in img_list or rec file.";
}
// update max value
max_width = std::max(max_width, rec.num_label);
} else {
LOG(FATAL) << "Not enough label packed in img_list or rec file.";
}
}
#pragma omp critical
{
max_label_width = std::max(max_label_width, max_width);
}
#pragma omp critical
{
max_label_width = std::max(max_label_width, max_width);
}
});
}
omp_exc_.Rethrow();
}
}
if (max_label_width > param_.label_pad_width) {
if (param_.label_pad_width > 0) {
LOG(FATAL) << "ImageDetRecordIOParser: label_pad_width: "
<< param_.label_pad_width << " smaller than estimated width: "
<< max_label_width;
<< param_.label_pad_width << " smaller than estimated width: " << max_label_width;
}
param_.label_pad_width = max_label_width;
}
Expand All @@ -336,19 +342,20 @@ inline void ImageDetRecordIOParser<DType>::Init(
}

source_.reset(dmlc::InputSplit::Create(
param_.path_imgrec.c_str(), param_.part_index,
param_.num_parts, "recordio"));
param_.path_imgrec.c_str(),
param_.part_index, param_.num_parts,
"recordio"));

if (param_.shuffle_chunk_size > 0) {
if (param_.shuffle_chunk_size > 4096) {
LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size
<< " MB which is larger than 4096 MB, please set "
"smaller chunk size";
<< " MB which is larger than 4096 MB, please set "
"smaller chunk size";
}
if (param_.shuffle_chunk_size < 4) {
LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size
<< " MB which is less than 4 MB, please set "
"larger chunk size";
<< " MB which is less than 4 MB, please set "
"larger chunk size";
}
// 1.1 ratio is for a bit more shuffle parts to avoid boundary issue
unsigned num_shuffle_parts =
Expand Down Expand Up @@ -381,92 +388,100 @@ ParseNext(std::vector<InstVector<DType>> *out_vec) {
out_vec->resize(param_.preprocess_threads);
#pragma omp parallel num_threads(param_.preprocess_threads)
{
CHECK(omp_get_num_threads() == param_.preprocess_threads);
int tid = omp_get_thread_num();
dmlc::RecordIOChunkReader reader(chunk, tid, param_.preprocess_threads);
ImageRecordIO rec;
dmlc::InputSplit::Blob blob;
// image data
InstVector<DType> &out = (*out_vec)[tid];
out.Clear();
while (reader.NextRecord(&blob)) {
// Opencv decode and augments
cv::Mat res;
rec.Load(blob.dptr, blob.size);
cv::Mat buf(1, rec.content_size, CV_8U, rec.content);
switch (param_.data_shape[0]) {
case 1:
res = cv::imdecode(buf, 0);
break;
case 3:
res = cv::imdecode(buf, 1);
break;
case 4:
// -1 to keep the number of channel of the encoded image, and not force gray or color.
res = cv::imdecode(buf, -1);
CHECK_EQ(res.channels(), 4)
<< "Invalid image with index " << rec.image_index()
<< ". Expected 4 channels, got " << res.channels();
break;
default:
LOG(FATAL) << "Invalid output shape " << param_.data_shape;
}
const int n_channels = res.channels();
// load label before augmentations
std::vector<float> label_buf;
if (this->label_map_ != nullptr) {
label_buf = label_map_->FindCopy(rec.image_index());
} else if (rec.label != nullptr) {
if (param_.label_width > 0) {
CHECK_EQ(param_.label_width, rec.num_label)
<< "rec file provide " << rec.num_label << "-dimensional label "
"but label_width is set to " << param_.label_width;
omp_exc_.Run([&] {
CHECK(omp_get_num_threads() == param_.preprocess_threads);
int tid = omp_get_thread_num();
dmlc::RecordIOChunkReader reader(chunk, tid, param_.preprocess_threads);
ImageRecordIO rec;
dmlc::InputSplit::Blob blob;
// image data
InstVector<DType> &out = (*out_vec)[tid];
out.Clear();
while (reader.NextRecord(&blob)) {
// Opencv decode and augments
cv::Mat res;
rec.Load(blob.dptr, blob.size);
cv::Mat buf(1, rec.content_size, CV_8U, rec.content);
switch (param_.data_shape[0]) {
case 1:
res = cv::imdecode(buf, 0);
break;
case 3:
res = cv::imdecode(buf, 1);
break;
case 4:
// -1 to keep the number of channel of the encoded image, and not
// force gray or color.
res = cv::imdecode(buf, -1);
CHECK_EQ(res.channels(), 4)
<< "Invalid image with index " << rec.image_index()
<< ". Expected 4 channels, got " << res.channels();
break;
default:
LOG(FATAL) << "Invalid output shape " << param_.data_shape;
}
label_buf.assign(rec.label, rec.label + rec.num_label);
} else {
LOG(FATAL) << "Not enough label packed in img_list or rec file.";
}
for (auto& aug : this->augmenters_[tid]) {
res = aug->Process(res, &label_buf, this->prnds_[tid].get());
}
out.Push(static_cast<unsigned>(rec.image_index()),
mshadow::Shape3(n_channels, param_.data_shape[1], param_.data_shape[2]),
mshadow::Shape1(param_.label_pad_width + 4));
const int n_channels = res.channels();
// load label before augmentations
std::vector<float> label_buf;
if (this->label_map_ != nullptr) {
label_buf = label_map_->FindCopy(rec.image_index());
} else if (rec.label != nullptr) {
if (param_.label_width > 0) {
CHECK_EQ(param_.label_width, rec.num_label)
<< "rec file provide " << rec.num_label
<< "-dimensional label "
"but label_width is set to "
<< param_.label_width;
}
label_buf.assign(rec.label, rec.label + rec.num_label);
} else {
LOG(FATAL) << "Not enough label packed in img_list or rec file.";
}
for (auto &aug : this->augmenters_[tid]) {
res = aug->Process(res, &label_buf, this->prnds_[tid].get());
}
out.Push(static_cast<unsigned>(rec.image_index()),
mshadow::Shape3(n_channels, param_.data_shape[1],
param_.data_shape[2]),
mshadow::Shape1(param_.label_pad_width + 4));

mshadow::Tensor<cpu, 3, DType> data = out.data().Back();
mshadow::Tensor<cpu, 3, DType> data = out.data().Back();

// For RGB or RGBA data, swap the B and R channel:
// OpenCV store as BGR (or BGRA) and we want RGB (or RGBA)
std::vector<int> swap_indices;
if (n_channels == 1) swap_indices = {0};
if (n_channels == 3) swap_indices = {2, 1, 0};
if (n_channels == 4) swap_indices = {2, 1, 0, 3};
// For RGB or RGBA data, swap the B and R channel:
// OpenCV store as BGR (or BGRA) and we want RGB (or RGBA)
std::vector<int> swap_indices;
if (n_channels == 1) swap_indices = {0};
if (n_channels == 3) swap_indices = {2, 1, 0};
if (n_channels == 4) swap_indices = {2, 1, 0, 3};

for (int i = 0; i < res.rows; ++i) {
uchar* im_data = res.ptr<uchar>(i);
for (int j = 0; j < res.cols; ++j) {
for (int k = 0; k < n_channels; ++k) {
for (int i = 0; i < res.rows; ++i) {
uchar *im_data = res.ptr<uchar>(i);
for (int j = 0; j < res.cols; ++j) {
for (int k = 0; k < n_channels; ++k) {
data[k][i][j] = im_data[swap_indices[k]];
}
im_data += n_channels;
}
im_data += n_channels;
}
mshadow::Tensor<cpu, 1> label = out.label().Back();
label = param_.label_pad_value;
// store info for real data_shape and label_width
label[0] = res.channels();
label[1] = res.rows;
label[2] = res.cols;
label[3] = label_buf.size();
mshadow::Copy(
label.Slice(4, 4 + label_buf.size()),
mshadow::Tensor<cpu, 1>(dmlc::BeginPtr(label_buf),
mshadow::Shape1(label_buf.size())));
res.release();
}
mshadow::Tensor<cpu, 1> label = out.label().Back();
label = param_.label_pad_value;
// store info for real data_shape and label_width
label[0] = res.channels();
label[1] = res.rows;
label[2] = res.cols;
label[3] = label_buf.size();
mshadow::Copy(label.Slice(4, 4 + label_buf.size()),
mshadow::Tensor<cpu, 1>(dmlc::BeginPtr(label_buf),
mshadow::Shape1(label_buf.size())));
res.release();
}
});
}
#else
LOG(FATAL) << "Opencv is needed for image decoding and augmenting.";
#endif
omp_exc_.Rethrow();
return true;
}

Expand Down
8 changes: 8 additions & 0 deletions src/io/iter_image_recordio_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ class ImageRecordIOParser2 {
bool legacy_shuffle_;
// whether mean image is ready.
bool meanfile_ready_;
/*! \brief OMPException obj to store and rethrow exceptions from omp blocks*/
dmlc::OMPException omp_exc_;
};

template<typename DType>
Expand Down Expand Up @@ -331,6 +333,7 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
// Copy
#pragma omp parallel for num_threads(param_.preprocess_threads)
for (int i = 0; i < n_to_copy; ++i) {
omp_exc_.Run([&] {
std::pair<unsigned, unsigned> place = inst_order_[inst_index_ + i];
const DataInst& batch = temp_[place.first][place.second];
for (unsigned j = 0; j < batch.data.size(); ++j) {
Expand All @@ -342,7 +345,9 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
batch.data[j].get_with_shape<cpu, 1, dtype>(mshadow::Shape1(unit_size_[j])));
});
}
});
}
omp_exc_.Rethrow();
n_to_out = n_to_copy;
inst_index_ += n_to_copy;
}
Expand Down Expand Up @@ -486,6 +491,7 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
unsigned gl_idx = current_size;
#pragma omp parallel num_threads(param_.preprocess_threads)
{
omp_exc_.Run([&] {
CHECK(omp_get_num_threads() == param_.preprocess_threads);
unsigned int tid = omp_get_thread_num();
// dmlc::RecordIOChunkReader reader(*chunk, tid, param_.preprocess_threads);
Expand Down Expand Up @@ -603,7 +609,9 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
mshadow::Shape1(label_buf.size())));
res.release();
}
});
}
omp_exc_.Rethrow();
return (std::min(batch_param_.batch_size, gl_idx) - current_size);
#else
LOG(FATAL) << "Opencv is needed for image decoding and augmenting.";
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ def test_Cifar10Rec():
for i in range(10):
assert(labelcount[i] == 5000)

def test_image_iter_exception():
def check_cifar10_exception():
get_cifar10()
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
rand_crop=False,
and_mirror=False,
shuffle=False,
data_shape=(5, 28, 28),
batch_size=100,
preprocess_threads=4,
prefetch_buffer=1)
labelcount = [0 for i in range(10)]
batchcount = 0
for batch in dataiter:
pass
assertRaises(MXNetError, check_cifar10_exception)

def test_NDArrayIter():
data = np.ones([1000, 2, 2])
Expand Down Expand Up @@ -435,3 +453,4 @@ def test_ImageRecordIter_seed_augmentation():
test_NDArrayIter_csr()
test_CSVIter()
test_ImageRecordIter_seed_augmentation()
test_image_iter_exception()