Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Fit API] update estimator #14849

Merged
merged 5 commits into from
May 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,36 @@ def _check_metrics(self, metrics):
return metrics

def _check_context(self, context):
# handle context
if isinstance(context, Context):
context = [context]
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
context = context
elif not context:
if num_gpus() > 0:
# infer available context
gpus = num_gpus()
available_gpus = [gpu(i) for i in range(gpus)]

if context:
# check context values, only accept Context or a list of Context
if isinstance(context, Context):
context = [context]
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
context = context
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check for the GPU device index too? also, try querying num_gpus() only once.

Copy link
Member Author

@roywei roywei Apr 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@szha thanks! I m now asserting context must be in available_context which is [cpu(), gpu(0), ...., gpu(num_gpus-1)]. added unit test

Copy link
Member

@szha szha May 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cpu(65536) is actually valid in mxnet regardless of the number of physical CPUs, whereas cpu() refers to cpu(0). For GPU, the check is fine, but for CPU, we might need to do a more general check.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about:

available_context = [mx.gpu(i) for i in range(num_gpus)]
assert ctx  in available_context or str(ctx).startswith('cpu')

else:
raise ValueError("context must be a Context or a list of Context, "
"for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], "
"refer to mxnet.Context:{}".format(context))
for ctx in context:
assert ctx in available_gpus or str(ctx).startswith('cpu'), \
"%s is not available, please make sure " \
"your context is in one of: mx.cpu(), %s" % \
(ctx, ", ".join([str(ctx) for ctx in available_gpus]))
else:
# provide default context
if gpus > 0:
# only use 1 GPU by default
if num_gpus() > 1:
if gpus > 1:
warnings.warn("You have multiple GPUs, gpu(0) will be used by default."
"To utilize all your GPUs, specify context as a list of gpus, "
"e.g. context=[mx.gpu(0), mx.gpu(1)] ")
context = [gpu(0)]
else:
context = [cpu()]
else:
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))
return context

def _initialize(self, initializer):
Expand Down Expand Up @@ -167,7 +179,8 @@ def prepare_loss_and_metrics(self):
self.train_metrics = [Accuracy()]
self.val_metrics = []
for loss in self.loss:
self.train_metrics.append(Loss(''.join([i for i in loss.name if not i.isdigit()])))
# remove trailing numbers from loss name to avoid confusion
self.train_metrics.append(Loss(loss.name.rstrip('1234567890')))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "Train " + metric.name
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_gluon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def test_context():
metrics=metrics,
context='cpu')

with assert_raises(AssertionError):
est = Estimator(net=net,
loss=loss,
metrics=metrics,
context=[mx.gpu(0), mx.gpu(100)])


def test_categorize_handlers():
class CustomHandler1(TrainBegin):
Expand Down