Skip to content

Commit

Permalink
test_ImageRecordIter_seed_augmentation flaky test fix (apache#12485)
Browse files Browse the repository at this point in the history
* Moves seed_aug parameter to ImageRecParserParam and re-seeds RNG before each augmentation to guarantee reproducibilit

* Update image record iterator tests to check the whole iterator not only first image
  • Loading branch information
perdasilva authored and haohuw committed Jun 23, 2019
1 parent 48e74b1 commit d8155ce
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 31 deletions.
13 changes: 1 addition & 12 deletions src/io/image_aug_default.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
int pad;
/*! \brief shape of the image data*/
TShape data_shape;
/*! \brief random seed for augmentations */
dmlc::optional<int> seed_aug;

// declare parameters
DMLC_DECLARE_PARAMETER(DefaultImageAugmentParam) {
Expand Down Expand Up @@ -188,8 +186,6 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
DMLC_DECLARE_FIELD(pad).set_default(0)
.describe("Change size from ``[width, height]`` into "
"``[pad + width + pad, pad + height + pad]`` by padding pixes");
DMLC_DECLARE_FIELD(seed_aug).set_default(dmlc::optional<int>())
.describe("Random seed for augmentations.");
}
};

Expand All @@ -208,9 +204,7 @@ std::vector<dmlc::ParamFieldInfo> ListDefaultAugParams() {
class DefaultImageAugmenter : public ImageAugmenter {
public:
// contructor
DefaultImageAugmenter() {
seed_init_state = false;
}
DefaultImageAugmenter() {}
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
std::vector<std::pair<std::string, std::string> > kwargs_left;
kwargs_left = param_.InitAllowUnknown(kwargs);
Expand Down Expand Up @@ -250,10 +244,6 @@ class DefaultImageAugmenter : public ImageAugmenter {
}
cv::Mat Process(const cv::Mat &src, std::vector<float> *label,
common::RANDOM_ENGINE *prnd) override {
if (!seed_init_state && param_.seed_aug.has_value()) {
prnd->seed(param_.seed_aug.value());
seed_init_state = true;
}
using mshadow::index_t;
bool is_cropped = false;

Expand Down Expand Up @@ -558,7 +548,6 @@ class DefaultImageAugmenter : public ImageAugmenter {
DefaultImageAugmentParam param_;
/*! \brief list of possible rotate angle */
std::vector<int> rotate_list_;
bool seed_init_state;
};

ImageAugmenter* ImageAugmenter::Create(const std::string& name) {
Expand Down
4 changes: 4 additions & 0 deletions src/io/image_iter_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
size_t shuffle_chunk_size;
/*! \brief the seed for chunk shuffling*/
int shuffle_chunk_seed;
/*! \brief random seed for augmentations */
dmlc::optional<int> seed_aug;

// declare parameters
DMLC_DECLARE_PARAMETER(ImageRecParserParam) {
Expand Down Expand Up @@ -165,6 +167,8 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
.describe("The data shuffle buffer size in MB. Only valid if shuffle is true.");
DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0)
.describe("The random seed for shuffling");
DMLC_DECLARE_FIELD(seed_aug).set_default(dmlc::optional<int>())
.describe("Random seed for augmentations.");
}
};

Expand Down
7 changes: 7 additions & 0 deletions src/io/iter_image_recordio_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,13 @@ inline size_t ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t*
cv::Mat res;
rec.Load(blob.dptr, blob.size);
cv::Mat buf(1, rec.content_size, CV_8U, rec.content);

// If augmentation seed is supplied
// Re-seed RNG to guarantee reproducible results
if (param_.seed_aug.has_value()) {
prnds_[tid]->seed(idx + param_.seed_aug.value() + kRandMagic);
}

switch (param_.data_shape[0]) {
case 1:
#if MXNET_USE_LIBJPEG_TURBO
Expand Down
79 changes: 60 additions & 19 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
import sys
from common import assertRaises
import unittest
try:
from itertools import izip_longest as zip_longest
except:
from itertools import zip_longest


def test_MNISTIter():
Expand Down Expand Up @@ -427,13 +431,56 @@ def check_CSVIter_synthetic(dtype='float32'):
for dtype in ['int32', 'int64', 'float32']:
check_CSVIter_synthetic(dtype=dtype)

@unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11359")
def test_ImageRecordIter_seed_augmentation():
get_cifar10()
seed_aug = 3

def assert_dataiter_items_equals(dataiter1, dataiter2):
"""
Asserts that two data iterators have the same numbner of batches,
that the batches have the same number of items, and that the items
are the equal.
"""
for batch1, batch2 in zip_longest(dataiter1, dataiter2):

# ensure iterators contain the same number of batches
# zip_longest will return None if on of the iterators have run out of batches
assert batch1 and batch2, 'The iterators do not contain the same number of batches'

# ensure batches are of same length
assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length'

# ensure batch data is the same
for i in range(0, len(batch1.data)):
data1 = batch1.data[i].asnumpy().astype(np.uint8)
data2 = batch2.data[i].asnumpy().astype(np.uint8)
assert(np.array_equal(data1, data2))

def assert_dataiter_items_not_equals(dataiter1, dataiter2):
"""
Asserts that two data iterators have the same numbner of batches,
that the batches have the same number of items, and that the items
are the _not_ equal.
"""
for batch1, batch2 in zip_longest(dataiter1, dataiter2):

# ensure iterators are of same length
# zip_longest will return None if on of the iterators have run out of batches
assert batch1 and batch2, 'The iterators do not contain the same number of batches'

# ensure batches are of same length
assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length'

# ensure batch data is the same
for i in range(0, len(batch1.data)):
data1 = batch1.data[i].asnumpy().astype(np.uint8)
data2 = batch2.data[i].asnumpy().astype(np.uint8)
if not np.array_equal(data1, data2):
return
assert False, 'Expected data iterators to be different, but they are the same'

# check whether to get constant images after fixing seed_aug
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -449,10 +496,8 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -468,12 +513,12 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_items_equals(dataiter1, dataiter2)

# check whether to get different images after change seed_aug
dataiter = mx.io.ImageRecordIter(
dataiter1.reset()
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -489,31 +534,27 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
assert(not np.array_equal(data,data2))

assert_dataiter_items_not_equals(dataiter1, dataiter2)

# check whether seed_aug changes the iterator behavior
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_items_equals(dataiter1, dataiter2)

if __name__ == "__main__":
test_NDArrayIter()
Expand Down

0 comments on commit d8155ce

Please sign in to comment.