@mxnet-label-bot add [Metric, pr-awaiting-review]

Thanks @ptrendx, I'll have a look.
Overall LGTM. Just a few comments.
Also, could you test the following:
- image classification example (train_mnist.py, train_imagenet.py)
- Speedometer's auto_reset=True and auto_reset=False
- tools/parse_log.py
python/mxnet/metric.py
Outdated
if "has_global_stats" in kwargs: | ||
self._has_global_stats = kwargs["has_global_stats"] | ||
else: | ||
self._has_global_stats = False |
self._has_global_stats = kwargs.get("has_global_stats", False)
?
done (used pop so the has_global_stats key is not kept in kwargs, to avoid breaking deserialization later).
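For illustration, a minimal sketch of the pop-based handling (the class name is hypothetical; only the has_global_stats kwarg comes from the diff):

```python
class EvalMetricSketch:
    def __init__(self, name, **kwargs):
        # pop() removes the flag so it does not linger in kwargs and
        # interfere with serializing/deserializing the metric config later.
        self._has_global_stats = kwargs.pop("has_global_stats", False)
        self.name = name
        self._kwargs = kwargs  # remaining kwargs, free of internal flags
```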
python/mxnet/metric.py
Outdated
         self.metrics.reset_stats()
     else:
-        self.sum_metric = self.metrics.fscore * self.metrics.total_examples
+        self.sum_metric = fscore * self.metrics.total_examples
+        self.global_sum_metric = fscore * self.metrics.total_examples
self.sum_metric = self.global_sum_metric =
?
done
python/mxnet/metric.py
Outdated
self.num_inst = self.metrics.total_examples
self.global_num_inst = self.metrics.total_examples
similarly here
done.
python/mxnet/metric.py
Outdated
self.num_inst = self._metrics.total_examples
self.global_num_inst = self._metrics.total_examples
similarly here
done.
I tested train_imagenet.py, tools/parse_log.py, and both auto_reset values for Speedometer.
@ptrendx thanks for the PR. Do you mind elaborating a bit more how this PR avoids GIL/speeds up metric evaluation?
It does not avoid the GIL; I just do less work in Python -
Thanks for the fix and explanation. LGTM
Thanks for the contribution @ptrendx!
Thanks. A few comments inline.
    else:
        return (self.name, self.global_sum_metric / self.global_num_inst)
else:
    return self.get()
If the user specifically asks for global statistics and they are not available, shouldn't we throw an exception rather than silently return local ones? Same in other places.
I'm not sure that is possible - doing this would break fully custom metrics (other than subclasses of the CustomMetric class, for which I added support) that did not implement global stats.
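A minimal sketch of the fallback being discussed, based on the snippet above (assuming the _has_global_stats flag from earlier in this thread; an illustration, not the exact implementation):

```python
def get_global(self):
    # Fall back to local statistics for metrics that never implemented
    # global stats, instead of raising - this keeps fully custom,
    # pre-existing metrics working unchanged.
    if self._has_global_stats:
        if self.global_num_inst == 0:
            return (self.name, float('nan'))
        return (self.name, self.global_sum_metric / self.global_num_inst)
    return self.get()
```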
    for metric in self.metrics:
        metric.reset_local()
except AttributeError:
    pass
Why is this required? When will we reach here? Can we please document it?
This is not added in this PR and I'm not sure myself why it is needed here (this function is basically a copy of the reset function, but calls reset_local on the children instead of reset).
-            self.sum_metric += (pred_label == label).sum()
+            num_correct = (pred_label == label).sum()
+            self.sum_metric += num_correct
+            self.global_sum_metric += num_correct
Sorry for the trivial question, but when will global metrics differ from local metrics with this logic?
They will be different when you call reset_local (which is called from Speedometer, for example) - this resets the local versions of the metrics while keeping the global versions intact, which was the point of this PR: to enable computing both per-batch and per-epoch statistics with only a single computation.
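A minimal sketch of that local/global split (counter names follow the diff; the class and update signature are simplified for illustration):

```python
class AccuracySketch:
    def __init__(self):
        self.sum_metric = 0.0          # local: reset between logging intervals
        self.global_sum_metric = 0.0   # global: accumulates over the epoch
        self.num_inst = 0
        self.global_num_inst = 0

    def update(self, num_correct, num_samples):
        # A single computation feeds both counter sets.
        self.sum_metric += num_correct
        self.global_sum_metric += num_correct
        self.num_inst += num_samples
        self.global_num_inst += num_samples

    def reset_local(self):
        # Called e.g. by Speedometer: clears only the local counters,
        # leaving the global (per-epoch) ones intact.
        self.sum_metric = 0.0
        self.num_inst = 0
```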
That said, this comment made me think about this again, and I found a bug in how I handle global statistics in the F1 and MCC metrics - fix and tests incoming. Thank you!
@@ -487,7 +569,7 @@ def update(self, labels, preds):
         for label, pred_label in zip(labels, preds):
             assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims'
-            pred_label = numpy.argsort(pred_label.asnumpy().astype('float32'), axis=1)
+            pred_label = numpy.argpartition(pred_label.asnumpy().astype('float32'), -self.top_k)
This is very nice. I think it warrants a comment on why argpartition is used here (for its performance benefit)
Sure, will do.
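For reference, a small standalone NumPy sketch of why argpartition helps here: it does an O(n) selection of the top_k entries instead of a full O(n log n) sort, at the cost of not returning them in sorted order (variable names are illustrative):

```python
import numpy

scores = numpy.random.rand(4, 1000).astype('float32')  # batch of class scores
top_k = 5

# Full sort along the class axis: O(n log n) per row.
top_sorted = numpy.argsort(scores, axis=1)[:, -top_k:]

# Partial selection: O(n) per row; the last top_k columns contain the
# indices of the top_k largest scores, in arbitrary order.
top_part = numpy.argpartition(scores, -top_k, axis=1)[:, -top_k:]

# Both select the same set of indices, only the ordering differs.
assert all(set(a) == set(b) for a, b in zip(top_sorted, top_part))
```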
python/mxnet/metric.py
Outdated
    Tuple of (str, float)
        Representing name of the metric and evaluation result.
    """
    num = self.global_num_inst if self.global_num_inst > 0 else float('nan')
return float('nan')?
Oh, this change got here by accident - I will revert it. Good catch, thank you!
When I added tests, it turned out that changing this is actually necessary to avoid a floating point division-by-zero exception (though I changed it slightly differently, to match the get function from the base class).
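For illustration, the guard pattern in question, mirroring the changed line above: with a NaN denominator the division yields NaN instead of raising ZeroDivisionError when no samples have been seen yet:

```python
num = self.global_num_inst if self.global_num_inst > 0 else float('nan')
value = self.global_sum_metric / num  # NaN, not an exception, when empty
```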
added test for global stats
I made fixes and added tests for all metrics. @sandeep-krishnamurthy please review again.
Thanks. LGTM.
Description
Currently, metrics are mostly evaluated on the CPU using NumPy. Due to the Python GIL they are evaluated sequentially, in a single thread, which may become a bottleneck once the number of GPUs used is large enough.
CC @vandanavk for comments.