This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Gluon image-classification example improvement #9633

Merged on Mar 3, 2018 (9 commits)

Conversation

zhreshold

Description

  • The Gluon image-classification example can now train on ImageNet properly.
  • Adds further optimizations, including float16 support.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
      • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
      • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
      • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
      • For user-facing API changes, the API doc string has been updated.
      • For new C++ functions in header files, their functionality and arguments are documented.
      • For new examples, a README.md is added to explain what the example does, the source of the dataset, the expected performance on the test set, and a reference to the original paper if applicable
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

@zhreshold zhreshold requested a review from szha as a code owner January 30, 2018 21:09
@@ -49,50 +54,60 @@ def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1, part_in

return train, val

class DataloaderIter(mx.io.DataIter):

Looks reusable for mx.io.
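
The reviewer's point is that the wrapper is generic: it adapts any batch-yielding loader to an iterator that can be reset between epochs. A minimal, library-independent sketch of that adapter pattern follows. It assumes a loader that can be iterated more than once; the class name and its methods are hypothetical and do not reproduce the PR's DataloaderIter or the mx.io.DataIter interface.

```python
class LoaderIterAdapter:
    """Hypothetical adapter: wraps a re-iterable loader behind reset()/next()."""

    def __init__(self, loader):
        self._loader = loader
        self._iter = iter(loader)

    def reset(self):
        # Start a fresh pass over the underlying loader (one epoch).
        self._iter = iter(self._loader)

    def next(self):
        # Raises StopIteration at the end of an epoch, like a plain iterator.
        return next(self._iter)

    __next__ = next  # allow use with the built-in next() and for-loops

    def __iter__(self):
        return self


batches = [([0.1, 0.2], 0), ([0.3, 0.4], 1)]  # stand-in for a real data loader
it = LoaderIterAdapter(batches)
first_epoch = list(it)
it.reset()
second_epoch = list(it)
```

Because the adapter only relies on iter() and next(), it has no dependency on the example script, which is what makes it a candidate for a shared module.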

                    help='validation record file to use, required for imagenet.')
                    help='dataset to use. options are mnist, cifar10, imagenet and dummy.')
parser.add_argument('--data', type=str, default='',
                    help='training directory of imagenet images, contains train/val subdirs.')

Make sure the option name reflects that it takes a directory.
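
One way to act on this suggestion is to rename the flag so the name itself says "directory". The sketch below uses only the standard argparse module; the name `--data-dir` is an illustrative assumption, not necessarily the name the PR ended up with.

```python
import argparse

parser = argparse.ArgumentParser(description='hypothetical excerpt')
# '--data-dir' is an illustrative name chosen for this sketch.
parser.add_argument('--data-dir', type=str, default='',
                    help='directory of imagenet images, contains train/val subdirs.')

# argparse maps '--data-dir' to the attribute 'data_dir'.
opt = parser.parse_args(['--data-dir', '/path/to/imagenet'])
print(opt.data_dir)
```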

parser.add_argument('--dtype', default='float32', type=str,
                    help='data type, float32 or float16 if applicable')
parser.add_argument('--save-frequency', default=10, type=int,
                    help='model save frequent, best model will always be saved')

What's the unit? Batches?
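
The question can be answered in the help string itself. The sketch below assumes the unit is epochs, which the quoted diff does not confirm; `should_save` is a hypothetical helper that shows how such a frequency would typically be checked.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save-frequency', default=10, type=int,
                    help='save a checkpoint every N epochs (assumed unit); '
                         'the best model is always saved')


def should_save(epoch, save_frequency):
    """Return True when a periodic checkpoint is due after 0-based `epoch`."""
    return save_frequency > 0 and (epoch + 1) % save_frequency == 0


opt = parser.parse_args([])
# Epochs (0-based) at which a periodic checkpoint would be written in 30 epochs.
due = [e for e in range(30) if should_save(e, opt.save_frequency)]
```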

parser.add_argument('--profile', action='store_true',
                    help='Option to turn on memory profiling for front-end, '\
                         'and prints out the memory usage by python function at the end.')
parser.add_argument('--top-k', type=int, default=0, help='add top-k metric if > 1')

There are many options already, which could hurt usability. Should we just hard-code this?
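
For context on what the `--top-k` option controls: top-k accuracy counts a prediction as correct when the true label is among the k highest-scoring classes. The pure-Python function below is a sketch of that definition only, not MXNet's metric implementation; the function name and the sample data are made up for illustration.

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row.
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top
    return hits / len(labels)


scores = [[0.1, 0.7, 0.2],   # true label 1: hit for k >= 1
          [0.5, 0.3, 0.2],   # true label 1: hit only for k >= 2
          [0.2, 0.3, 0.5]]   # true label 0: hit only for k = 3
labels = [1, 1, 0]
top1 = top_k_accuracy(scores, labels, 1)
top2 = top_k_accuracy(scores, labels, 2)
```

With k = 1 this reduces to ordinary accuracy, which is why the option only adds a metric when k is greater than 1.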

if isinstance(ctx, mx.Context):
ctx = [ctx]
net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
def train(opt, ctx):

Do you intend to add a metric argument here too? That way train only uses these components, which would make the code look simpler.

name, val_acc = test(ctx, metric, val_data)
val_msg = ','.join(['%s=%f'%(n, a) for n, a in zip(as_list(name), as_list(val_acc))])
logger.info('[Epoch %d] validation: %s'%(epoch, val_msg))
top1 = val_acc[0] if isinstance(val_acc, list) else val_acc

This can be refactored.
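
One possible refactoring reuses the same list-normalizing idea for both the log message and the top-1 value, so the isinstance check disappears from the training loop. In this sketch `as_list` is re-implemented under the assumption that it wraps scalars in a list (the PR's own helper is not shown in the diff), and `format_validation` is a hypothetical name.

```python
def as_list(value):
    """Wrap a scalar in a list; pass lists and tuples through as lists."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def format_validation(name, val_acc):
    """Return (top1, message) for one metric or for several metrics."""
    names, accs = as_list(name), as_list(val_acc)
    msg = ','.join('%s=%f' % (n, a) for n, a in zip(names, accs))
    # The first metric is treated as top-1 accuracy, as in the quoted code.
    return accs[0], msg


top1, msg = format_validation(['accuracy', 'top_k_accuracy_5'], [0.71, 0.9])
single_top1, single_msg = format_validation('accuracy', 0.5)
```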

best_acc = top1
fname = os.path.join(opt.prefix, '%s_best.params' % (opt.model))
net.save_params(fname)
logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)

Consider extracting these into a checkpoint function, so that the train function contains only high-level, readable code.
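
A sketch of what such a checkpoint function could look like. To keep it runnable without MXNet, the saving call is injected as `save_fn`, which stands in for net.save_params. The function name, its parameters, and the periodic file-name pattern are all assumptions made for illustration; only the `_best.params` pattern comes from the quoted code.

```python
import os


def save_checkpoint(save_fn, prefix, model, epoch, top1, best_acc, save_frequency):
    """Save periodic and best checkpoints; return the updated best accuracy.

    `save_fn` stands in for net.save_params (an injected callable).
    """
    # Periodic checkpoint; the '<model>_<epoch>.params' name is hypothetical.
    if save_frequency > 0 and (epoch + 1) % save_frequency == 0:
        save_fn(os.path.join(prefix, '%s_%d.params' % (model, epoch)))
    # Best-so-far checkpoint, mirroring the '<model>_best.params' name above.
    if top1 > best_acc:
        best_acc = top1
        save_fn(os.path.join(prefix, '%s_best.params' % model))
    return best_acc


saved = []   # records the paths that would have been written
best = 0.0
for epoch, acc in enumerate([0.50, 0.62, 0.58, 0.70]):
    best = save_checkpoint(saved.append, 'ckpt', 'resnet', epoch, acc, best, 2)
```

With this shape, the training loop needs a single call per epoch and keeps no file-name logic of its own.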

@zhreshold

@szha That's too many lines for the mx.io module. Any suggestions?

szha commented Jan 31, 2018

Maybe move it to mx.contrib.io first?

@szha szha merged commit 8780096 into apache:master Mar 3, 2018
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018
* backup

* backup

* finish

* fix multiple

* fix

* fix

* fix padding

* add more tests

* fix expanduser
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018 (same commit list as above)
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018 (same commit list as above)