Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Fit API] update estimator (#14849)
Browse files Browse the repository at this point in the history
* address comments

* add comment

* check available context

* fix bug

* change cpu check
  • Loading branch information
roywei authored and szha committed May 2, 2019
1 parent 0748b47 commit 33e8845
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
37 changes: 25 additions & 12 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,36 @@ def _check_metrics(self, metrics):
return metrics

def _check_context(self, context):
# handle context
if isinstance(context, Context):
context = [context]
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
context = context
elif not context:
if num_gpus() > 0:
# infer available context
gpus = num_gpus()
available_gpus = [gpu(i) for i in range(gpus)]

if context:
# check context values, only accept Context or a list of Context
if isinstance(context, Context):
context = [context]
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
context = context
else:
raise ValueError("context must be a Context or a list of Context, "
"for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], "
"refer to mxnet.Context:{}".format(context))
for ctx in context:
assert ctx in available_gpus or str(ctx).startswith('cpu'), \
"%s is not available, please make sure " \
"your context is in one of: mx.cpu(), %s" % \
(ctx, ", ".join([str(ctx) for ctx in available_gpus]))
else:
# provide default context
if gpus > 0:
# only use 1 GPU by default
if num_gpus() > 1:
if gpus > 1:
warnings.warn("You have multiple GPUs, gpu(0) will be used by default."
"To utilize all your GPUs, specify context as a list of gpus, "
"e.g. context=[mx.gpu(0), mx.gpu(1)] ")
context = [gpu(0)]
else:
context = [cpu()]
else:
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))
return context

def _initialize(self, initializer):
Expand Down Expand Up @@ -167,7 +179,8 @@ def prepare_loss_and_metrics(self):
self.train_metrics = [Accuracy()]
self.val_metrics = []
for loss in self.loss:
self.train_metrics.append(Loss(''.join([i for i in loss.name if not i.isdigit()])))
# remove trailing numbers from loss name to avoid confusion
self.train_metrics.append(Loss(loss.name.rstrip('1234567890')))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "Train " + metric.name
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_gluon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def test_context():
metrics=metrics,
context='cpu')

with assert_raises(AssertionError):
est = Estimator(net=net,
loss=loss,
metrics=metrics,
context=[mx.gpu(0), mx.gpu(100)])


def test_categorize_handlers():
class CustomHandler1(TrainBegin):
Expand Down

0 comments on commit 33e8845

Please sign in to comment.