-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for patiently accommodating the last minute design change requests. I have a few comments would like you to know what you think and create a follow up PR if necessary.
losses = [] | ||
for loss in self.loss: | ||
losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) | ||
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what if the model had multiple loss functions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
multi loss will be supported in #14628, let's get the first version into master and iterate on that.
val_metrics=val_metrics)) | ||
event_handlers.append(LoggingHandler(train_metrics=train_metrics, | ||
val_metrics=val_metrics)) | ||
warnings.warn("No Event Handler specified, default %s are used. " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you write this warning using the LoggingHandler's logger? so the user has one place to control the log levels and look for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point! for now we can only do this for estimator and handlers, any other warning from mxnet and gluon still can't be controlled. tracked here: https://issues.apache.org/jira/browse/MXNET-1395
losses = [] | ||
for loss in self.loss: | ||
losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) | ||
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same thing, using only a single loss?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as above
multi loss will be supported in #14628, let's get the first version into master and iterate on that.
for metric in self.train_metrics: | ||
metric.reset() | ||
|
||
def batch_end(self, estimator, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to capture this for every batch by default. I think we should update once per epoch by default and let the user control.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
once batch end we lost that batch's label and pred
self.train_metrics = train_metrics or [] | ||
# order to be called among all callbacks | ||
# metrics need to be calculated before other callbacks can access them | ||
self.priority = -np.Inf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the priority should be exposed in the base class, otherwise the user who writes custom handlers has no clue about this and the order is based on this.
I am not sure how python resolves the ambiguity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This mechanism is mainly for internal use. I m making sure metric and validation are called first and logging are called last. I'm trying to reduce the information user need to know, they can order their own event handlers in the list before passing to fit()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets call this out explicitly in the documentation.
self.epoch_period = epoch_period | ||
self.batch_period = batch_period | ||
self.val_metrics = val_metrics | ||
self.num_batches = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how will the user control batch_period
and num_batches
when you are using this when no handlers are specified. does he have to specify all the handlers to make this change?
do you think we should make this static, so user can independently update this, one drawback is if there are multiple ValidationHandlers used in the same process all of them get the same value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point! if we provide default handler one by one (so user don't need to re-create all just to custom one of them). We need a mechanism to make sure all handlers has the reference of the same set of metric objects. or make handlers an attributes so they can be configured after default handlers been created. tracked https://issues.apache.org/jira/browse/MXNET-1396
file_location=None, | ||
verbose=LOG_VERBOSITY_PER_EPOCH, | ||
train_metrics=None, | ||
val_metrics=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it possible to customize, given that prepare_loss_and_metrics happen after the estimator is created. If at all the order should be
e = Estimator()
e.prepare_loss_and_metrics()
lh = LoggingHandler(..., train_metrics=[e.train_metrics[0], e.train_metrics[1]], ..)
is this what you were thinking?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes that's correct
self.logger.info(msg) | ||
self.batch_index += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't this be in the estimator itself, why should all handlers maintain this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not all handlers need the same set of these infomation, so they maintain whatever they want to use. This also prevents if one handler changed self.estimator.total_steps
wrongly, it will cause all other handlers to fail
'for example val_accuracy', self.monitor)) | ||
self.estimator.net.save_parameters(self.filepath) | ||
if np.isnan(monitor_value): | ||
warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use logger, so the user can control.
@@ -191,18 +287,23 @@ class CheckpointHandler(EventHandler): | |||
|
|||
def __init__(self, | |||
filepath, | |||
monitor='val_accuracy', | |||
monitor=None, | |||
verbose=0, | |||
save_best_only=False, | |||
mode='auto', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you expand what different modes mean in the doc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's explained in the doc string
* improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests
* improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests
* [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346) * base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests * Fixed issue where the estimator was printing beyond the dataset size … (#14464) * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI * [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442) * added estimator unittests * add more tests for estimator * added validation logic * added error handlers, unittests * improve val stats * fix pylint * fix pylint * update unit test * fix tests * fix tests * updated metrics, val logic * trigger ci * trigger ci * update metric, batch_fn error handler * update context logic, add default metric * [MXNet-1340][Fit API]Update train stats (#14494) * add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers * [MXNet-1375][Fit API]Added RNN integration test for fit() API (#14547) * Added RNN integration test for fit() API * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports * CPU test doesn't require nvidiadocker container * Modified the structure by removing the redundant code * [MXNet-1343][Fit API]Add CNN integration test for fit() API (#14405) * added cnn intg tests for fit api * updated cnn intg tests * added functions for nightly test * updated runtime_function * updated intg tests * updated init, datapath, refs * added validation data * update cpu test * refactor code * updated context * [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (#14587) * Retrieve Batch size and Logging verbose support for Gluon fit() API * NIT changes * Addressed review comments: shifted the batch size code to a separate method, sentence correction * Modified unittest * removed redundant parameter * Resolve CI test failure * only support DataLoader for now, future PRs will include DataIter to DataLoader converter * Get the number of samples from shape attribute instead of length due to low space complexity * Simplified batch size retrieval code * removed batch_size parameter from fit() method and fixed the tests * Verbose exception handling * Assigning constant to a verbose * Modified exception message * Resolved undefined class reference * Addressed review comments: Modified verbose level names, docs, variable names * Update estimator.py * move estimator to contrib (#14633) * move to gluon contrib (#14635) * [Fit API] improve event handlers (#14685) * improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests * [MXNET-1396][Fit-API] Update default handler logic (#14765) * move to nightly for binaries * update default handler * fix pylint * trigger ci * trigger ci * [Fit API] update estimator (#14849) * address comments * add comment * check available context * fix bug * change cpu check * [Fit-API] Adress PR comments (#14885) * address comments * update checkpoint * test symbol save * address comments * add resume * update doc and resume checkpoint * update docs * trigger ci * trigger ci
* improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests
* [MXNet-1334][Fit API]base class for estimator and eventhandler (apache#14346) * base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests * Fixed issue where the estimator was printing beyond the dataset size … (apache#14464) * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI * [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (apache#14442) * added estimator unittests * add more tests for estimator * added validation logic * added error handlers, unittests * improve val stats * fix pylint * fix pylint * update unit test * fix tests * fix tests * updated metrics, val logic * trigger ci * trigger ci * update metric, batch_fn error handler * update context logic, add default metric * [MXNet-1340][Fit API]Update train stats (apache#14494) * add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers * [MXNet-1375][Fit API]Added RNN integration test for fit() API (apache#14547) * Added RNN integration test for fit() API * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports * CPU test doesn't require nvidiadocker container * Modified the structure by removing the redundant code * [MXNet-1343][Fit API]Add CNN integration test for fit() API (apache#14405) * added cnn intg tests for fit api * updated cnn intg tests * added functions for nightly test * updated runtime_function * updated intg tests * updated init, datapath, refs * added validation data * update cpu test * refactor code * updated context * [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (apache#14587) * Retrieve Batch size and Logging verbose support for Gluon fit() API * NIT changes * Addressed review comments: shifted the batch size code to a separate method, sentence correction * Modified unittest * removed redundant parameter * Resolve CI test failure * only support DataLoader for now, future PRs will include DataIter to DataLoader converter * Get the number of samples from shape attribute instead of length due to low space complexity * Simplified batch size retrieval code * removed batch_size parameter from fit() method and fixed the tests * Verbose exception handling * Assigning constant to a verbose * Modified exception message * Resolved undefined class reference * Addressed review comments: Modified verbose level names, docs, variable names * Update estimator.py * move estimator to contrib (apache#14633) * move to gluon contrib (apache#14635) * [Fit API] improve event handlers (apache#14685) * improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests * [MXNET-1396][Fit-API] Update default handler logic (apache#14765) * move to nightly for binaries * update default handler * fix pylint * trigger ci * trigger ci * [Fit API] update estimator (apache#14849) * address comments * add comment * check available context * fix bug * change cpu check * [Fit-API] Adress PR comments (apache#14885) * address comments * update checkpoint * test symbol save * address comments * add resume * update doc and resume checkpoint * update docs * trigger ci * trigger ci
* improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests
Description
Making the follwing on evetn handlers based on the design here:
https://cwiki.apache.org/confluence/display/MXNET/Callback+Design+for+Fit+Loop
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments