From 383e33f0ac4e696702ac2053df33f2e36623e210 Mon Sep 17 00:00:00 2001 From: Tong He Date: Tue, 19 Jun 2018 15:23:19 -0700 Subject: [PATCH] Add standard ResNet data augmentation for ImageRecordIter (#11027) * add resnet augmentation * add test * fix scope * fix warning * fix lint * fix lint * add color jitter and pca noise * fix center crop * merge * fix lint * Trigger CI * fix * fix augmentation implementation * add checks for parameters * modify training script * fix compile error * Trigger CI * Trigger CI * modify error message * Trigger CI * Trigger CI * Trigger CI * improve script in example * fix script * clear code * Trigger CI * set min_aspect_ratio to optional, move rotation and pad before random resized crop * fix * Trigger CI * Trigger CI * Trigger CI * fix default values * Trigger CI --- example/image-classification/common/data.py | 48 +++- .../image-classification/train_imagenet.py | 4 +- src/io/image_aug_default.cc | 241 ++++++++++++++++-- tests/python/train/test_resnet_aug.py | 173 +++++++++++++ 4 files changed, 435 insertions(+), 31 deletions(-) create mode 100644 tests/python/train/test_resnet_aug.py diff --git a/example/image-classification/common/data.py b/example/image-classification/common/data.py index 05f5ddc4506e..bfaadb3ff6b8 100755 --- a/example/image-classification/common/data.py +++ b/example/image-classification/common/data.py @@ -43,9 +43,9 @@ def add_data_args(parser): def add_data_aug_args(parser): aug = parser.add_argument_group( 'Image augmentations', 'implemented in src/io/image_aug_default.cc') - aug.add_argument('--random-crop', type=int, default=1, + aug.add_argument('--random-crop', type=int, default=0, help='if or not randomly crop the image') - aug.add_argument('--random-mirror', type=int, default=1, + aug.add_argument('--random-mirror', type=int, default=0, help='if or not randomly flip horizontally') aug.add_argument('--max-random-h', type=int, default=0, help='max change of hue, whose range is [0, 180]') @@ -53,8 +53,13 @@ def add_data_aug_args(parser): help='max change of saturation, whose range is [0, 255]') aug.add_argument('--max-random-l', type=int, default=0, help='max change of intensity, whose range is [0, 255]') + aug.add_argument('--min-random-aspect-ratio', type=float, default=None, + help='min value of aspect ratio, whose value is either None or a positive value.') aug.add_argument('--max-random-aspect-ratio', type=float, default=0, - help='max change of aspect ratio, whose range is [0, 1]') + help='max value of aspect ratio. If min_random_aspect_ratio is None, ' + 'the aspect ratio range is [1-max_random_aspect_ratio, ' + '1+max_random_aspect_ratio], otherwise it is ' + '[min_random_aspect_ratio, max_random_aspect_ratio].') aug.add_argument('--max-random-rotate-angle', type=int, default=0, help='max angle to rotate, whose range is [0, 360]') aug.add_argument('--max-random-shear-ratio', type=float, default=0, @@ -63,16 +68,28 @@ def add_data_aug_args(parser): help='max ratio to scale') aug.add_argument('--min-random-scale', type=float, default=1, help='min ratio to scale, should >= img_size/input_shape. otherwise use --pad-size') + aug.add_argument('--max-random-area', type=float, default=1, + help='max area to crop in random resized crop, whose range is [0, 1]') + aug.add_argument('--min-random-area', type=float, default=1, + help='min area to crop in random resized crop, whose range is [0, 1]') + aug.add_argument('--brightness', type=float, default=0, + help='brightness jittering, whose range is [0, 1]') + aug.add_argument('--contrast', type=float, default=0, + help='contrast jittering, whose range is [0, 1]') + aug.add_argument('--saturation', type=float, default=0, + help='saturation jittering, whose range is [0, 1]') + aug.add_argument('--pca-noise', type=float, default=0, + help='pca noise, whose range is [0, 1]') + aug.add_argument('--random-resized-crop', type=int, default=0, + help='whether to use random resized crop') return aug -def set_data_aug_level(aug, level): - if level >= 1: - aug.set_defaults(random_crop=1, random_mirror=1) - if level >= 2: - aug.set_defaults(max_random_h=36, max_random_s=50, max_random_l=50) - if level >= 3: - aug.set_defaults(max_random_rotate_angle=10, max_random_shear_ratio=0.1, max_random_aspect_ratio=0.25) - +def set_resnet_aug(aug): + # standard data augmentation setting for resnet training + aug.set_defaults(random_crop=1, random_resized_crop=1) + aug.set_defaults(min_random_area=0.08) + aug.set_defaults(max_random_aspect_ratio=4./3., min_random_aspect_ratio=3./4.) + aug.set_defaults(brightness=0.4, contrast=0.4, saturation=0.4, pca_noise=0.1) class SyntheticDataIter(DataIter): def __init__(self, num_classes, data_shape, max_iter, dtype): @@ -135,8 +152,16 @@ def get_rec_iter(args, kv=None): max_random_scale = args.max_random_scale, pad = args.pad_size, fill_value = 127, + random_resized_crop = args.random_resized_crop, min_random_scale = args.min_random_scale, max_aspect_ratio = args.max_random_aspect_ratio, + min_aspect_ratio = args.min_random_aspect_ratio, + max_random_area = args.max_random_area, + min_random_area = args.min_random_area, + brightness = args.brightness, + contrast = args.contrast, + saturation = args.saturation, + pca_noise = args.pca_noise, random_h = args.max_random_h, random_s = args.max_random_s, random_l = args.max_random_l, @@ -156,6 +181,7 @@ def get_rec_iter(args, kv=None): mean_r = rgb_mean[0], mean_g = rgb_mean[1], mean_b = rgb_mean[2], + resize = 256, data_name = 'data', label_name = 'softmax_label', batch_size = args.batch_size, diff --git a/example/image-classification/train_imagenet.py b/example/image-classification/train_imagenet.py index f465fbc5f469..a90b6aead237 100644 --- a/example/image-classification/train_imagenet.py +++ b/example/image-classification/train_imagenet.py @@ -30,8 +30,8 @@ fit.add_fit_args(parser) data.add_data_args(parser) data.add_data_aug_args(parser) - # use a large aug level - data.set_data_aug_level(parser, 3) + # uncomment to set standard augmentation for resnet training + # data.set_resnet_aug(parser) parser.set_defaults( # network network = 'resnet', diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc index 22af7d927500..f7d08b92f176 100644 --- a/src/io/image_aug_default.cc +++ b/src/io/image_aug_default.cc @@ -46,10 +46,14 @@ struct DefaultImageAugmentParam : public dmlc::Parameter min_aspect_ratio; /*! \brief random shear the image [-max_shear_ratio, max_shear_ratio] */ float max_shear_ratio; /*! \brief max crop size */ @@ -58,12 +62,24 @@ struct DefaultImageAugmentParam : public dmlc::Parameter()) .describe("Change the aspect (namely width/height) to a random value " - "in ``[1 - max_aspect_ratio, 1 + max_aspect_ratio]``"); + "in ``[min_aspect_ratio, max_aspect_ratio]``"); DMLC_DECLARE_FIELD(max_shear_ratio).set_default(0.0f) .describe("Apply a shear transformation (namely ``(x,y)->(x+my,y)``) " "with ``m`` randomly chose from " "``[-max_shear_ratio, max_shear_ratio]``"); DMLC_DECLARE_FIELD(max_crop_size).set_default(-1) .describe("Crop both width and height into a random size in " - "``[min_crop_size, max_crop_size]``"); + "``[min_crop_size, max_crop_size].``" + "Ignored if ``random_resized_crop`` is True."); DMLC_DECLARE_FIELD(min_crop_size).set_default(-1) .describe("Crop both width and height into a random size in " - "``[min_crop_size, max_crop_size]``"); + "``[min_crop_size, max_crop_size].``" + "Ignored if ``random_resized_crop`` is True."); DMLC_DECLARE_FIELD(max_random_scale).set_default(1.0f) .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly" - " chosen from ``[min_random_scale, max_random_scale]``"); + " chosen from ``[min_random_scale, max_random_scale]``. " + "Ignored if ``random_resized_crop`` is True."); DMLC_DECLARE_FIELD(min_random_scale).set_default(1.0f) .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly" - " chosen from ``[min_random_scale, max_random_scale]``"); + " chosen from ``[min_random_scale, max_random_scale]``" + "Ignored if ``random_resized_crop`` is True."); + DMLC_DECLARE_FIELD(max_random_area).set_default(1.0f) + .describe("Change the area (namely width * height) to a random value " + "in ``[min_random_area, max_random_area]``. " + "Ignored if ``random_resized_crop`` is False."); + DMLC_DECLARE_FIELD(min_random_area).set_default(1.0f) + .describe("Change the area (namely width * height) to a random value " + "in ``[min_random_area, max_random_area]``. " + "Ignored if ``random_resized_crop`` is False."); DMLC_DECLARE_FIELD(max_img_size).set_default(1e10f) .describe("Set the maximal width and height after all resize and" " rotate argumentation are applied"); DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f) .describe("Set the minimal width and height after all resize and" " rotate argumentation are applied"); + DMLC_DECLARE_FIELD(brightness).set_default(0.0f) + .describe("Add a random value in ``[-brightness, brightness]`` to " + "the brightness of image."); + DMLC_DECLARE_FIELD(contrast).set_default(0.0f) + .describe("Add a random value in ``[-contrast, contrast]`` to " + "the contrast of image."); + DMLC_DECLARE_FIELD(saturation).set_default(0.0f) + .describe("Add a random value in ``[-saturation, saturation]`` to " + "the saturation of image."); + DMLC_DECLARE_FIELD(pca_noise).set_default(0.0f) + .describe("Add PCA based noise to the image."); DMLC_DECLARE_FIELD(random_h).set_default(0) .describe("Add a random value in ``[-random_h, random_h]`` to " "the H channel in HSL color space."); @@ -197,6 +245,18 @@ class DefaultImageAugmenter : public ImageAugmenter { cv::Mat Process(const cv::Mat &src, std::vector *label, common::RANDOM_ENGINE *prnd) override { using mshadow::index_t; + bool is_cropped = false; + + float max_aspect_ratio = 1.0f; + float min_aspect_ratio = 1.0f; + if (param_.min_aspect_ratio.has_value()) { + max_aspect_ratio = param_.max_aspect_ratio; + min_aspect_ratio = param_.min_aspect_ratio.value(); + } else { + max_aspect_ratio = 1 + param_.max_aspect_ratio; + min_aspect_ratio = 1 - param_.max_aspect_ratio; + } + cv::Mat res; if (param_.resize != -1) { int new_height, new_width; @@ -220,8 +280,9 @@ class DefaultImageAugmenter : public ImageAugmenter { // normal augmentation by affine transformation. if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f - || param_.rotate > 0 || rotate_list_.size() > 0 || param_.max_random_scale != 1.0 - || param_.min_random_scale != 1.0 || param_.max_aspect_ratio != 0.0f + || param_.rotate > 0 || rotate_list_.size() > 0 + || param_.max_random_scale != 1.0f || param_.min_random_scale != 1.0 + || min_aspect_ratio != 1.0f || max_aspect_ratio != 1.0f || param_.max_img_size != 1e10f || param_.min_img_size != 0.0f) { std::uniform_real_distribution rand_uniform(0, 1); // shear @@ -236,11 +297,17 @@ class DefaultImageAugmenter : public ImageAugmenter { float a = cos(angle / 180.0 * M_PI); float b = sin(angle / 180.0 * M_PI); // scale - float scale = rand_uniform(*prnd) * - (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale; + float scale = 1.0f; + if (!param_.random_resized_crop) { + scale = rand_uniform(*prnd) * + (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale; + } // aspect ratio - float ratio = rand_uniform(*prnd) * - param_.max_aspect_ratio * 2 - param_.max_aspect_ratio + 1; + float ratio = 1.0f; + if (!param_.random_resized_crop) { + ratio = rand_uniform(*prnd) * + (max_aspect_ratio - min_aspect_ratio) + min_aspect_ratio; + } float hs = 2 * scale / (1 + ratio); float ws = ratio * hs; // new width and height @@ -276,8 +343,59 @@ class DefaultImageAugmenter : public ImageAugmenter { cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); } - // crop logic - if (param_.max_crop_size != -1 || param_.min_crop_size != -1) { + if (param_.random_resized_crop) { + // random resize crop + CHECK(param_.min_random_scale == 1.0f && + param_.max_random_scale == 1.0f && + param_.min_crop_size == -1 && + param_.max_crop_size == -1 && + !param_.rand_crop) << + "\nSetting random_resized_crop to true conflicts with " + "min_random_scale, max_random_scale, " + "min_crop_size, max_crop_size, " + "and rand_crop."; + + if (param_.max_random_area != 1.0f || param_.min_random_area != 1.0f + || max_aspect_ratio != 1.0f || min_aspect_ratio != 1.0f) { + CHECK(min_aspect_ratio > 0.0f); + CHECK(param_.min_random_area <= param_.max_random_area); + CHECK(min_aspect_ratio <= max_aspect_ratio); + std::uniform_real_distribution rand_uniform_area(param_.min_random_area, + param_.max_random_area); + std::uniform_real_distribution rand_uniform_ratio(min_aspect_ratio, + max_aspect_ratio); + std::uniform_real_distribution rand_uniform(0, 1); + float area = res.rows * res.cols; + for (int i = 0; i < 10; ++i) { + float rand_area = rand_uniform_area(*prnd); + float ratio = rand_uniform_ratio(*prnd); + float target_area = area * rand_area; + int y_area = std::round(std::sqrt(target_area / ratio)); + int x_area = std::round(std::sqrt(target_area * ratio)); + if (rand_uniform(*prnd) > 0.5) { + float temp_y_area = y_area; + y_area = x_area; + x_area = temp_y_area; + } + if (y_area <= res.rows && x_area <= res.cols) { + index_t rand_y_area = + std::uniform_int_distribution(0, res.rows - y_area)(*prnd); + index_t rand_x_area = + std::uniform_int_distribution(0, res.cols - x_area)(*prnd); + cv::Rect roi(rand_x_area, rand_y_area, x_area, y_area); + int interpolation_method = GetInterMethod(param_.inter_method, x_area, y_area, + param_.data_shape[2], + param_.data_shape[1], prnd); + cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]), + 0, 0, interpolation_method); + is_cropped = true; + break; + } + } + } + } else if (!param_.random_resized_crop && + (param_.max_crop_size != -1 || param_.min_crop_size != -1)) { + // random_crop CHECK(res.cols >= param_.max_crop_size && res.rows >= \ param_.max_crop_size && param_.max_crop_size >= param_.min_crop_size) << "input image size smaller than max_crop_size"; @@ -296,7 +414,28 @@ class DefaultImageAugmenter : public ImageAugmenter { param_.data_shape[2], param_.data_shape[1], prnd); cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]) , 0, 0, interpolation_method); - } else { + is_cropped = true; + } + + if (!is_cropped) { + // center crop + int interpolation_method = GetInterMethod(param_.inter_method, res.cols, res.rows, + param_.data_shape[2], + param_.data_shape[1], prnd); + if (res.rows < param_.data_shape[1]) { + index_t new_cols = static_cast(static_cast(param_.data_shape[1]) / + static_cast(res.rows) * + static_cast(res.cols)); + cv::resize(res, res, cv::Size(new_cols, param_.data_shape[1]), + 0, 0, interpolation_method); + } + if (res.cols < param_.data_shape[2]) { + index_t new_rows = static_cast(static_cast(param_.data_shape[2]) / + static_cast(res.cols) * + static_cast(res.rows)); + cv::resize(res, res, cv::Size(param_.data_shape[2], new_rows), + 0, 0, interpolation_method); + } CHECK(static_cast(res.rows) >= param_.data_shape[1] && static_cast(res.cols) >= param_.data_shape[2]) << "input image size smaller than input shape"; @@ -312,13 +451,48 @@ class DefaultImageAugmenter : public ImageAugmenter { res = res(roi); } + // color jitter + if (param_.brightness > 0.0f || param_.contrast > 0.0f || param_.saturation > 0.0f) { + std::uniform_real_distribution rand_uniform(0, 1); + float alpha_b = 1.0 + std::uniform_real_distribution(-param_.brightness, + param_.brightness)(*prnd); + float alpha_c = 1.0 + std::uniform_real_distribution(-param_.contrast, + param_.contrast)(*prnd); + float alpha_s = 1.0 + std::uniform_real_distribution(-param_.saturation, + param_.saturation)(*prnd); + int rand_order[3] = {0, 1, 2}; + std::random_shuffle(std::begin(rand_order), std::end(rand_order)); + for (int i = 0; i < 3; ++i) { + if (rand_order[i] == 0) { + // brightness + res.convertTo(res, -1, alpha_b, 0); + } + if (rand_order[i] == 1) { + // contrast + cvtColor(res, temp_, CV_RGB2GRAY); + float gray_mean = cv::mean(temp_)[0]; + res.convertTo(res, -1, alpha_c, (1 - alpha_c) * gray_mean); + } + if (rand_order[i] == 2) { + // saturation + cvtColor(res, temp_, CV_RGB2GRAY); + cvtColor(temp_, temp_, CV_GRAY2BGR); + cv::addWeighted(res, alpha_s, temp_, 1 - alpha_s, 0.0, res); + } + } + } + // color space augmentation if (param_.random_h != 0 || param_.random_s != 0 || param_.random_l != 0) { std::uniform_real_distribution rand_uniform(0, 1); cvtColor(res, res, CV_BGR2HLS); - int h = rand_uniform(*prnd) * param_.random_h * 2 - param_.random_h; - int s = rand_uniform(*prnd) * param_.random_s * 2 - param_.random_s; - int l = rand_uniform(*prnd) * param_.random_l * 2 - param_.random_l; + // use an approximation of gaussian distribution to reduce extreme value + float rh = rand_uniform(*prnd); rh += 4 * rand_uniform(*prnd); rh = rh / 5; + float rs = rand_uniform(*prnd); rs += 4 * rand_uniform(*prnd); rs = rs / 5; + float rl = rand_uniform(*prnd); rl += 4 * rand_uniform(*prnd); rl = rl / 5; + int h = rh * param_.random_h * 2 - param_.random_h; + int s = rs * param_.random_s * 2 - param_.random_s; + int l = rl * param_.random_l * 2 - param_.random_l; int temp[3] = {h, l, s}; int limit[3] = {180, 255, 255}; for (int i = 0; i < res.rows; ++i) { @@ -333,14 +507,45 @@ class DefaultImageAugmenter : public ImageAugmenter { } cvtColor(res, res, CV_HLS2BGR); } + + // pca noise + if (param_.pca_noise > 0.0f) { + std::normal_distribution rand_normal(0, param_.pca_noise); + float pca_alpha_r = rand_normal(*prnd); + float pca_alpha_g = rand_normal(*prnd); + float pca_alpha_b = rand_normal(*prnd); + float pca_r = eigvec[0][0] * pca_alpha_r + eigvec[0][1] * pca_alpha_g + + eigvec[0][2] * pca_alpha_b; + float pca_g = eigvec[1][0] * pca_alpha_r + eigvec[1][1] * pca_alpha_g + + eigvec[1][2] * pca_alpha_b; + float pca_b = eigvec[2][0] * pca_alpha_r + eigvec[2][1] * pca_alpha_g + + eigvec[2][2] * pca_alpha_b; + float pca[3] = { pca_b, pca_g, pca_r }; + for (int i = 0; i < res.rows; ++i) { + for (int j = 0; j < res.cols; ++j) { + for (int k = 0; k < 3; ++k) { + int vp = res.at(i, j)[k]; + vp += pca[k]; + vp = std::max(0, std::min(255, vp)); + res.at(i, j)[k] = vp; + } + } + } + } return res; } + private: // temporal space cv::Mat temp_; // rotation param cv::Mat rotateM_; + // eigval and eigvec for adding pca noise + // store eigval * eigvec as eigvec + float eigvec[3][3] = { { 55.46f * -0.5675f, 4.794f * 0.7192f, 1.148f * 0.4009f }, + { 55.46f * -0.5808f, 4.794f * -0.0045f, 1.148f * -0.8140f }, + { 55.46f * -0.5836f, 4.794f * -0.6948f, 1.148f * 0.4203f } }; // parameters DefaultImageAugmentParam param_; /*! \brief list of possible rotate angle */ diff --git a/tests/python/train/test_resnet_aug.py b/tests/python/train/test_resnet_aug.py new file mode 100644 index 000000000000..62c531bb6374 --- /dev/null +++ b/tests/python/train/test_resnet_aug.py @@ -0,0 +1,173 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: skip-file +import sys +sys.path.insert(0, '../../python') +import mxnet as mx +import numpy as np +import os, pickle, gzip +import logging +from mxnet.test_utils import get_cifar10 + +batch_size = 128 + +# small mlp network +def get_net(): + data = mx.symbol.Variable('data') + float_data = mx.symbol.Cast(data=data, dtype="float32") + fc1 = mx.symbol.FullyConnected(float_data, name='fc1', num_hidden=128) + act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") + fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) + act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") + fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) + softmax = mx.symbol.SoftmaxOutput(fc3, name="softmax") + return softmax + +# check data +get_cifar10() + +def get_iterator(kv): + data_shape = (3, 28, 28) + + train = mx.io.ImageRecordIter( + path_imgrec = "data/cifar/train.rec", + mean_img = "data/cifar/mean.bin", + data_shape = data_shape, + batch_size = batch_size, + random_resized_crop = True, + min_aspect_ratio = 0.75, + max_aspect_ratio = 1.33, + min_random_area = 0.08, + max_random_area = 1, + brightness = 0.4, + contrast = 0.4, + saturation = 0.4, + pca_noise = 0.1, + rand_mirror = True, + num_parts = kv.num_workers, + part_index = kv.rank) + train = mx.io.PrefetchingIter(train) + + val = mx.io.ImageRecordIter( + path_imgrec = "data/cifar/test.rec", + mean_img = "data/cifar/mean.bin", + rand_crop = False, + rand_mirror = False, + data_shape = data_shape, + batch_size = batch_size, + num_parts = kv.num_workers, + part_index = kv.rank) + + return (train, val) + +num_epoch = 1 + +def run_cifar10(train, val, use_module): + train.reset() + val.reset() + devs = [mx.cpu(0)] + net = get_net() + mod = mx.mod.Module(net, context=devs) + optim_args = {'learning_rate': 0.001, 'wd': 0.00001, 'momentum': 0.9} + eval_metrics = ['accuracy'] + if use_module: + executor = mx.mod.Module(net, context=devs) + executor.fit( + train, + eval_data=val, + optimizer_params=optim_args, + eval_metric=eval_metrics, + num_epoch=num_epoch, + arg_params=None, + aux_params=None, + begin_epoch=0, + batch_end_callback=mx.callback.Speedometer(batch_size, 50), + epoch_end_callback=None) + else: + executor = mx.model.FeedForward.create( + net, + train, + ctx=devs, + eval_data=val, + eval_metric=eval_metrics, + num_epoch=num_epoch, + arg_params=None, + aux_params=None, + begin_epoch=0, + batch_end_callback=mx.callback.Speedometer(batch_size, 50), + epoch_end_callback=None, + **optim_args) + + ret = executor.score(val, eval_metrics) + if use_module: + ret = list(ret) + logging.info('final accuracy = %f', ret[0][1]) + assert (ret[0][1] > 0.08) + else: + logging.info('final accuracy = %f', ret[0]) + assert (ret[0] > 0.08) + +class CustomDataIter(mx.io.DataIter): + def __init__(self, data): + super(CustomDataIter, self).__init__() + self.data = data + self.batch_size = data.provide_data[0][1][0] + + # use legacy tuple + self.provide_data = [(n, s) for n, s in data.provide_data] + self.provide_label = [(n, s) for n, s in data.provide_label] + + def reset(self): + self.data.reset() + + def next(self): + return self.data.next() + + def iter_next(self): + return self.data.iter_next() + + def getdata(self): + return self.data.getdata() + + def getlabel(self): + return self.data.getlable() + + def getindex(self): + return self.data.getindex() + + def getpad(self): + return self.data.getpad() + +def test_cifar10(): + # print logging by default + logging.basicConfig(level=logging.DEBUG) + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + logging.getLogger('').addHandler(console) + kv = mx.kvstore.create("local") + # test float32 input + (train, val) = get_iterator(kv) + run_cifar10(train, val, use_module=False) + run_cifar10(train, val, use_module=True) + + # test legecay tuple in provide_data and provide_label + run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=False) + run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=True) + +if __name__ == "__main__": + test_cifar10()