From 6777ac8a9457161b446ab799d3d2e408b662a957 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 30 Apr 2019 11:47:43 -0700 Subject: [PATCH 1/5] address comments --- .../gluon/contrib/estimator/estimator.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index d30595a6efb5..c6e42c79fe0a 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -88,12 +88,17 @@ def _check_metrics(self, metrics): return metrics def _check_context(self, context): - # handle context - if isinstance(context, Context): - context = [context] - elif isinstance(context, list) and all([isinstance(c, Context) for c in context]): - context = context - elif not context: + if context: + # check context values, only accept Context or a list of Context + if isinstance(context, Context): + context = [context] + elif isinstance(context, list) and all([isinstance(c, Context) for c in context]): + context = context + else: + raise ValueError("context must be a Context or a list of Context, " + "refer to mxnet.Context:{}".format(context)) + else: + # provide default context if num_gpus() > 0: # only use 1 GPU by default if num_gpus() > 1: @@ -103,9 +108,6 @@ def _check_context(self, context): context = [gpu(0)] else: context = [cpu()] - else: - raise ValueError("context must be a Context or a list of Context, " - "refer to mxnet.Context:{}".format(context)) return context def _initialize(self, initializer): @@ -167,7 +169,7 @@ def prepare_loss_and_metrics(self): self.train_metrics = [Accuracy()] self.val_metrics = [] for loss in self.loss: - self.train_metrics.append(Loss(''.join([i for i in loss.name if not i.isdigit()]))) + self.train_metrics.append(Loss(loss.name.rstrip('1234567890'))) for metric in self.train_metrics: val_metric = copy.deepcopy(metric) metric.name = "Train " + metric.name From 8017685777bd258eb06eab6e0ca367340534c73f Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 30 Apr 2019 11:50:01 -0700 Subject: [PATCH 2/5] add comment --- python/mxnet/gluon/contrib/estimator/estimator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index c6e42c79fe0a..774974e9cbec 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -169,6 +169,7 @@ def prepare_loss_and_metrics(self): self.train_metrics = [Accuracy()] self.val_metrics = [] for loss in self.loss: + # remove trailing numbers from loss name to avoid confusion self.train_metrics.append(Loss(loss.name.rstrip('1234567890'))) for metric in self.train_metrics: val_metric = copy.deepcopy(metric) From 938dbb4ed15754d7f22d7b94241f28a02f53c183 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 30 Apr 2019 14:49:57 -0700 Subject: [PATCH 3/5] check available context --- .../mxnet/gluon/contrib/estimator/estimator.py | 16 ++++++++++++++-- tests/python/unittest/test_gluon_estimator.py | 6 ++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 774974e9cbec..a3a05208d5b3 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -88,6 +88,12 @@ def _check_metrics(self, metrics): return metrics def _check_context(self, context): + # infer available context + available_context = [cpu()] + gpus = num_gpus() + for i in range(gpus): + available_context.append(gpu(0)) + if context: # check context values, only accept Context or a list of Context if isinstance(context, Context): @@ -96,12 +102,18 @@ def _check_context(self, context): context = context else: raise ValueError("context must be a Context or a list of Context, " + "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], " "refer to mxnet.Context:{}".format(context)) + for ctx in context: + assert ctx in available_context, \ + "%s is not available, please make sure " \ + "your context is in one of: %s" % \ + (ctx, " ,".join([str(ctx) for ctx in available_context])) else: # provide default context - if num_gpus() > 0: + if gpus > 0: # only use 1 GPU by default - if num_gpus() > 1: + if gpus > 1: warnings.warn("You have multiple GPUs, gpu(0) will be used by default." "To utilize all your GPUs, specify context as a list of gpus, " "e.g. context=[mx.gpu(0), mx.gpu(1)] ") diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 643214212e3a..b25baa255165 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -258,6 +258,12 @@ def test_context(): metrics=metrics, context='cpu') + with assert_raises(AssertionError): + est = Estimator(net=net, + loss=loss, + metrics=metrics, + context=[mx.gpu(0), mx.gpu(100)]) + def test_categorize_handlers(): class CustomHandler1(TrainBegin): From 615018bdd01747829fd6b660cfea28149c7f586d Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 30 Apr 2019 16:41:48 -0700 Subject: [PATCH 4/5] fix bug --- python/mxnet/gluon/contrib/estimator/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index a3a05208d5b3..9fcba645f472 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -92,7 +92,7 @@ def _check_context(self, context): available_context = [cpu()] gpus = num_gpus() for i in range(gpus): - available_context.append(gpu(0)) + available_context.append(gpu(i)) if context: # check context values, only accept Context or a list of Context From 8fb57b12e70043565359dc76149323c672a54ef3 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 30 Apr 2019 21:20:28 -0700 Subject: [PATCH 5/5] change cpu check --- python/mxnet/gluon/contrib/estimator/estimator.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 9fcba645f472..f43f17520654 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -89,10 +89,8 @@ def _check_metrics(self, metrics): def _check_context(self, context): # infer available context - available_context = [cpu()] gpus = num_gpus() - for i in range(gpus): - available_context.append(gpu(i)) + available_gpus = [gpu(i) for i in range(gpus)] if context: # check context values, only accept Context or a list of Context @@ -105,10 +103,10 @@ def _check_context(self, context): "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], " "refer to mxnet.Context:{}".format(context)) for ctx in context: - assert ctx in available_context, \ + assert ctx in available_gpus or str(ctx).startswith('cpu'), \ "%s is not available, please make sure " \ - "your context is in one of: %s" % \ - (ctx, " ,".join([str(ctx) for ctx in available_context])) + "your context is in one of: mx.cpu(), %s" % \ + (ctx, ", ".join([str(ctx) for ctx in available_gpus])) else: # provide default context if gpus > 0: