
Optimization of metric evaluation #13471

Merged: 9 commits, Dec 13, 2018

Conversation

@ptrendx (Member) commented on Nov 30, 2018

Description

Currently, metrics are mostly evaluated on the CPU using NumPy. Due to the Python GIL, they are evaluated sequentially in a single thread, which may become a problem once the number of GPUs in use is large enough.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Changed the TopKAccuracy metric implementation from numpy.argsort to numpy.argpartition, which in my experiments on NDArrays of shape (208, 1000) with top_k=5 is ~4x faster (see the sketch after this list)
  • Added global statistics to metrics. We noticed the problem with metric speed after PR [MXNET-698] Correct train-metric log to reflect epoch metric #12182, which introduced computing the metrics twice: once for the immediate accuracy values and once for the accuracy over the entire epoch. This is wasteful, since the calculations needed are exactly the same in both cases. This PR introduces additional fields in the EvalMetric class (global_sum_metric and global_num_inst), accessors for them (get_global and get_global_name_values), and a function reset_local which resets only the non-global versions of the statistics. It also modifies all current metrics to use the global statistics. The code stays backward compatible (one can still just use the get/get_name_values/reset functions) while eliminating the overhead of calculating the statistics twice in the fit function.
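
As a rough illustration of the argsort-to-argpartition change, here is a minimal sketch in plain NumPy; the shape and top_k match the measurements above, while the random scores and variable names are made up:

```python
import numpy as np

# Hypothetical scores matching the shape from the experiment above:
# a batch of 208 predictions over 1000 classes.
scores = np.random.rand(208, 1000).astype('float32')
top_k = 5

# Full sort: O(n log n) per row, even though only 5 indices are needed.
top_sorted = np.argsort(scores, axis=1)[:, -top_k:]

# Partial partition: O(n) per row. The last top_k columns contain the
# top-k indices, in no particular order.
top_part = np.argpartition(scores, -top_k, axis=1)[:, -top_k:]

# Top-k accuracy only asks whether the true label is in the top-k set,
# so the order does not matter and both variants give the same answer.
assert all(set(a) == set(b) for a, b in zip(top_sorted, top_part))
```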

CC @vandanavk for comments.

@ptrendx requested a review from szha as a code owner on November 30, 2018 00:09
@vandanavk (Contributor):

@mxnet-label-bot add [Metric, pr-awaiting-review]

Thanks @ptrendx, I'll have a look

@marcoabreu added the Metric and pr-awaiting-review labels on Nov 30, 2018
@vandanavk (Contributor) left a comment:

Overall LGTM. Just a few comments.
Also, could you test the following:

  • image classification example (train_mnist.py, train_imagenet.py)
  • Speedometer's auto_reset=True and auto_reset=False
  • tools/parse_log.py

if "has_global_stats" in kwargs:
self._has_global_stats = kwargs["has_global_stats"]
else:
self._has_global_stats = False

Contributor:

self._has_global_stats = kwargs.get("has_global_stats", False) ?

@ptrendx (Member, Author):

done (used pop instead of get, so the has_global_stats key is removed from kwargs and doesn't interfere with deserialization later)
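
For reference, the merged version presumably reads something like this one-liner (a sketch; pop both reads the flag and removes the key):

```python
# Read the flag and drop the key so it doesn't linger in kwargs
# (e.g., through later serialization/deserialization of the config).
self._has_global_stats = kwargs.pop("has_global_stats", False)
```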

        self.metrics.reset_stats()
    else:
-       self.sum_metric = self.metrics.fscore * self.metrics.total_examples
+       self.sum_metric = fscore * self.metrics.total_examples
+       self.global_sum_metric = fscore * self.metrics.total_examples

Contributor:

self.sum_metric = self.global_sum_metric = ?

@ptrendx (Member, Author):

done

self.num_inst = self.metrics.total_examples
self.global_num_inst = self.metrics.total_examples

Contributor:

similarly here

@ptrendx (Member, Author):

done.

self.num_inst = self._metrics.total_examples
self.global_num_inst = self._metrics.total_examples

Contributor:

similarly here

@ptrendx (Member, Author):

done.

@ptrendx (Member, Author) commented on Nov 30, 2018

I tested train_imagenet.py, tools/parse_log.py and both auto_reset values for Speedometer.

@eric-haibin-lin (Member):

@ptrendx thanks for the PR. Do you mind elaborating a bit more on how this PR avoids the GIL / speeds up metric evaluation?

@ptrendx (Member, Author) commented on Dec 2, 2018

It does not avoid the GIL; I just do less work in Python:

  • top-k evaluation using numpy argpartition instead of argsort is faster (you don't need a full sort to get the top k elements)
  • currently, metrics are calculated twice during training to get both local and per-epoch results; this PR makes them evaluated once, with the result used for both per-batch and per-epoch statistics.

@eric-haibin-lin (Member) left a comment:

Thanks for the fix and explanation. LGTM

@lupesko (Contributor) commented on Dec 5, 2018

Thanks for the contribution @ptrendx !
Adding @sandeep-krishnamurthy to help review and merge.

@sandeep-krishnamurthy (Contributor) left a comment:

Thanks. A few comments inline.

        else:
            return (self.name, self.global_sum_metric / self.global_num_inst)
    else:
        return self.get()

Contributor:

If the user specifically requests global statistics and they are not available, shouldn't we throw an exception rather than silently returning the local ones? Same in other places.

@ptrendx (Member, Author):

I'm not sure that is possible: doing this would break fully custom metrics that did not implement global stats (other than subclasses of the CustomMetric class, for which I added support).

    for metric in self.metrics:
        metric.reset_local()
except AttributeError:
    pass

Contributor:

Why is this required? When will we reach this branch? Can we please document it?

@ptrendx (Member, Author):

This is not added in this PR and I'm not sure myself why it is needed here (this function is basically a copy of the reset function, but calls reset_local on children instead of reset).

-       self.sum_metric += (pred_label == label).sum()
+       num_correct = (pred_label == label).sum()
+       self.sum_metric += num_correct
+       self.global_sum_metric += num_correct

Contributor:

Sorry for the trivial question, but when will the global metrics differ from the local metrics with this logic?

@ptrendx (Member, Author):

They will be different when you call reset_local (which is called from Speedometer, for example): it resets the local versions of the metrics while keeping the global versions intact. That was the point of this PR: to enable computing both per-batch and per-epoch statistics with only a single computation. A sketch of the resulting flow follows.
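
A minimal sketch of that usage pattern (the dummy data and log_interval are made up for illustration; get_global and reset_local are the APIs this PR adds):

```python
import mxnet as mx

acc = mx.metric.Accuracy()
log_interval = 10  # made-up logging period

for i in range(100):  # stand-in for an epoch of batches
    # Dummy labels/predictions; in real training these come from the model.
    labels = [mx.nd.array([0, 1, 1])]
    preds = [mx.nd.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])]
    acc.update(labels, preds)
    if (i + 1) % log_interval == 0:
        print(acc.get())    # local stats: since the last reset_local()
        acc.reset_local()   # clear local counters, keep the global ones

print(acc.get_global())     # global stats: over the entire epoch
acc.reset()                 # clear both local and global counters
```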

@ptrendx (Member, Author):

That said, this comment made me think about this again and I found a bug in how I handle global statistics in F1 and MCC metrics - fix and tests incoming. Thank you!

@@ -487,7 +569,7 @@ def update(self, labels, preds):

    for label, pred_label in zip(labels, preds):
        assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims'
-       pred_label = numpy.argsort(pred_label.asnumpy().astype('float32'), axis=1)
+       pred_label = numpy.argpartition(pred_label.asnumpy().astype('float32'), -self.top_k)

Contributor:

This is very nice. I think it warrants a comment on why argpartition is used here (for its performance benefit).

@ptrendx (Member, Author):

Sure, will do.
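
One way such a comment could read (a sketch, not the exact text that was committed):

```python
# Use argpartition instead of argsort: top-k accuracy only needs the set
# of top_k indices, not their order, and partitioning is O(n) per row
# versus O(n log n) for a full sort.
pred_label = numpy.argpartition(pred_label.asnumpy().astype('float32'), -self.top_k)
```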

        Tuple of (str, float)
            Representing name of the metric and evaluation result.
        """
        num = self.global_num_inst if self.global_num_inst > 0 else float('nan')

Contributor:

return float('nan')?

@ptrendx (Member, Author):

Oh, this change got here by accident - I will revert it. Good catch, thank you!

@ptrendx (Member, Author):

When I added tests, it turned out that changing this is actually necessary to avoid a floating-point division-by-zero exception (though I changed it slightly differently, to match the get function of the base class).
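
The guard being discussed, roughly (a sketch consistent with the snippet above and the base-class get behavior; the exact merged code may differ):

```python
def get_global(self):
    """Return (name, value); NaN when no instances have been seen yet."""
    if self._has_global_stats:
        if self.global_num_inst == 0:
            return (self.name, float('nan'))  # avoid ZeroDivisionError
        return (self.name, self.global_sum_metric / self.global_num_inst)
    return self.get()
```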

@ptrendx (Member, Author) commented on Dec 12, 2018

I made the fixes and added tests for all metrics. @sandeep-krishnamurthy please review again.

@sandeep-krishnamurthy (Contributor) left a comment:

Thanks. LGTM.
