-
Notifications
You must be signed in to change notification settings - Fork 14
/
speedometer_reset.patch
32 lines (31 loc) · 1.37 KB
/
speedometer_reset.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py
index 396f5a1..544eab2 100644
--- a/python/mxnet/callback.py
+++ b/python/mxnet/callback.py
@@ -96,13 +96,16 @@ class Speedometer(object):
frequent: int
How many batches between calculations.
Defaults to calculating & logging every 50 batches.
+ auto_reset : bool
+ Reset the metric after each log.
"""
- def __init__(self, batch_size, frequent=50):
+ def __init__(self, batch_size, frequent=50, auto_reset=False):
self.batch_size = batch_size
self.frequent = frequent
self.init = False
self.tic = 0
self.last_count = 0
+ self.auto_reset = auto_reset
def __call__(self, param):
"""Callback to Show speed."""
@@ -116,7 +119,8 @@ class Speedometer(object):
speed = self.frequent * self.batch_size / (time.time() - self.tic)
if param.eval_metric is not None:
name_value = param.eval_metric.get_name_value()
- param.eval_metric.reset()
+ if self.auto_reset:
+ param.eval_metric.reset()
for name, value in name_value:
logging.info('Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f',
param.epoch, count, speed, name, value)