Gluon image-classification example improvement #9633
Conversation
example/gluon/data.py (Outdated)
@@ -49,50 +54,60 @@ def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1, part_in

    return train, val

class DataloaderIter(mx.io.DataIter):
Looks reusable for mx.io.
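Such a wrapper could plausibly look like the sketch below (the constructor argument and the omitted provide_data/provide_label properties are assumptions for illustration, not the PR's actual code):

```python
import mxnet as mx

class DataloaderIter(mx.io.DataIter):
    """Wrap a gluon.data.DataLoader so it can be used where an mx.io.DataIter is expected."""
    def __init__(self, loader):
        super(DataloaderIter, self).__init__()
        self._loader = loader
        self._iter = iter(loader)

    def reset(self):
        # Restart iteration over the underlying DataLoader at the start of each epoch.
        self._iter = iter(self._loader)

    def next(self):
        # The DataLoader yields (data, label) NDArray pairs; repack them as a DataBatch.
        # StopIteration from the underlying iterator ends the epoch as DataIter expects.
        data, label = next(self._iter)
        return mx.io.DataBatch(data=[data], label=[label])
```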
                    help='validation record file to use, required for imagenet.')
                    help='dataset to use. options are mnist, cifar10, imagenet and dummy.')
parser.add_argument('--data', type=str, default='',
                    help='training directory of imagenet images, contains train/val subdirs.')
make sure that the name of the option reflects that it's for a directory.
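For instance, the option could be renamed so the directory intent is obvious (the exact flag name --data-dir below is only a suggestion, not necessarily what the PR adopted):

```python
parser.add_argument('--data-dir', type=str, default='',
                    help='directory of ImageNet images, containing train/ and val/ subdirectories.')
```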
parser.add_argument('--dtype', default='float32', type=str,
                    help='data type, float32 or float16 if applicable')
parser.add_argument('--save-frequency', default=10, type=int,
                    help='model save frequency, best model will always be saved')
what's the unit? batch?
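A clarified help string could state the unit explicitly; the sketch below assumes the frequency is measured in epochs, which may not match the author's intent:

```python
parser.add_argument('--save-frequency', default=10, type=int,
                    help='how often to save the model, in epochs; the best model is always saved.')
```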
parser.add_argument('--profile', action='store_true',
                    help='Option to turn on memory profiling for front-end, '
                         'and prints out the memory usage by python function at the end.')
parser.add_argument('--top-k', type=int, default=0, help='add top-k metric if > 1')
There are many options already, which could affect its usability. Should we just hard-code this?
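If it were hard-coded, the metric could simply be built as a composite up front; top-5 below is an arbitrary choice for illustration:

```python
metric = mx.metric.CompositeEvalMetric(
    [mx.metric.Accuracy(), mx.metric.TopKAccuracy(top_k=5)])
```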
if isinstance(ctx, mx.Context):
    ctx = [ctx]
net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
def train(opt, ctx):
Do you intend to add a metric argument here too? This way train only involves using these components, which can make the code look simpler.
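A possible signature change along those lines (a sketch only; how the metric is constructed and threaded through is an assumption, and net is taken to exist at module scope as in the example):

```python
def train(opt, ctx, metric):
    # Accept the metric as a parameter instead of building it inside train(),
    # so train() only wires the pre-built components together.
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
    ...
```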
name, val_acc = test(ctx, metric, val_data)
val_msg = ','.join(['%s=%f'%(n, a) for n, a in zip(as_list(name), as_list(val_acc))])
logger.info('[Epoch %d] validation: %s'%(epoch, val_msg))
top1 = val_acc[0] if isinstance(val_acc, list) else val_acc
can be refactored
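One way to refactor it is to normalize the return value once with a small helper; the name first() here is made up for illustration:

```python
def first(x):
    """Return the first element of a list, or the value itself if it is not a list."""
    return x[0] if isinstance(x, list) else x

top1 = first(val_acc)
```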
    best_acc = top1
    fname = os.path.join(opt.prefix, '%s_best.params' % (opt.model))
    net.save_params(fname)
    logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)
Consider extracting these into a checkpoint function, so that the train function contains only the high-level, readable code.
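A checkpointing helper along these lines could keep train() short (a sketch only; the name save_checkpoint and its arguments are invented here, and it reuses the example's existing os, logger, and net.save_params):

```python
def save_checkpoint(net, epoch, top1, best_acc, prefix, model_name):
    """Save a '<model>_best.params' snapshot when validation accuracy improves."""
    if top1 > best_acc:
        fname = os.path.join(prefix, '%s_best.params' % model_name)
        net.save_params(fname)
        logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f',
                    epoch, fname, top1)
        return top1
    return best_acc
```

Inside the epoch loop, the quoted lines would then collapse to something like best_acc = save_checkpoint(net, epoch, top1, best_acc, opt.prefix, opt.model).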
@szha Too many lines for mx.io module, suggestions?
maybe move it to mx.contrib.io first?
Force-pushed from 18bba06 to 86099af (compare)
* backup * backup * finish * fix multiple * fix * fix * fix padding * add more tests * fix expanduser
Description
Checklist
Essentials
Passed code style checking (make lint)
Changes
Comments