
Update LoggingHandler to support logging per interval #16922

Merged: 10 commits, Dec 7, 2019
46 changes: 22 additions & 24 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -227,29 +227,22 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat

Parameters
----------
verbose : int, default LOG_PER_EPOCH
Limit the granularity of metrics displayed during training process.
verbose=LOG_PER_EPOCH: display metrics every epoch
verbose=LOG_PER_BATCH: display metrics every batch
log_interval: int or str, default 'epoch'
Logging interval during training.
log_interval='epoch': display metrics every epoch
log_interval=integer k: display metrics every interval of k batches
train_metrics : list of EvalMetrics
Training metrics to be logged, logged at batch end, epoch end, train end.
val_metrics : list of EvalMetrics
Validation metrics to be logged, logged at epoch end, train end.
"""

LOG_PER_EPOCH = 1
LOG_PER_BATCH = 2

def __init__(self, verbose=LOG_PER_EPOCH,
def __init__(self, log_interval='epoch',
train_metrics=None,
val_metrics=None):
super(LoggingHandler, self).__init__()
if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
raise ValueError("verbose level must be either LOG_PER_EPOCH or "
"LOG_PER_BATCH, received %s. "
"E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
% verbose)
self.verbose = verbose
if not isinstance(log_interval, int) and log_interval != 'epoch':
raise ValueError("log_interval must be either an integer or string 'epoch'")
self.train_metrics = _check_metrics(train_metrics)
self.val_metrics = _check_metrics(val_metrics)
self.batch_index = 0
@@ -258,6 +251,7 @@ def __init__(self, verbose=LOG_PER_EPOCH,
# logging handler need to be called at last to make sure all states are updated
# it will also shut down logging at train end
self.priority = np.Inf
self.log_interval = log_interval

def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
@@ -275,6 +269,7 @@ def train_begin(self, estimator, *args, **kwargs):
self.current_epoch = 0
self.batch_index = 0
self.processed_samples = 0
self.log_interval_time = 0

def train_end(self, estimator, *args, **kwargs):
train_time = time.time() - self.train_start
@@ -286,31 +281,34 @@ def train_end(self, estimator, *args, **kwargs):
estimator.logger.info(msg.rstrip(', '))

def batch_begin(self, estimator, *args, **kwargs):
if self.verbose == self.LOG_PER_BATCH:
if isinstance(self.log_interval, int):
self.batch_start = time.time()

def batch_end(self, estimator, *args, **kwargs):
if self.verbose == self.LOG_PER_BATCH:
if isinstance(self.log_interval, int):
batch_time = time.time() - self.batch_start
msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
self.processed_samples += kwargs['batch'][0].shape[0]
msg += '[Samples %s] ' % (self.processed_samples)
msg += 'time/batch: %.3fs ' % batch_time
for metric in self.train_metrics:
# only log current training loss & metric after each batch
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))
self.log_interval_time += batch_time
if self.batch_index % self.log_interval == 0:
msg += 'time/interval: %.3fs ' % self.log_interval_time
self.log_interval_time = 0
for metric in self.train_metrics:
# only log current training loss & metric after each interval
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))
self.batch_index += 1

def epoch_begin(self, estimator, *args, **kwargs):
if self.verbose >= self.LOG_PER_EPOCH:
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
self.epoch_start = time.time()
estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)

def epoch_end(self, estimator, *args, **kwargs):
if self.verbose >= self.LOG_PER_EPOCH:
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
epoch_time = time.time() - self.epoch_start
msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
for monitor in self.train_metrics + self.val_metrics:
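For context (not part of the diff): a minimal usage sketch of the new `log_interval` option, assuming the MXNet 1.6-era Gluon Estimator API exercised in the tests below. The network, data, and metric are toy placeholders mirroring the test fixtures.

```python
import mxnet as mx
from mxnet.gluon import nn, loss
from mxnet.gluon.contrib.estimator import estimator
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler

# Toy network and data, mirroring the test fixtures further down.
net = nn.Sequential()
net.add(nn.Dense(128, activation='relu', flatten=False), nn.Dense(1))
net.initialize()

data = mx.nd.ones((100, 100))
label = mx.nd.zeros((100, 1))
dataloader = mx.gluon.data.DataLoader(
    mx.gluon.data.dataset.ArrayDataset(data, label), batch_size=8)

acc = mx.metric.Accuracy()
ce_loss = loss.SoftmaxCrossEntropyLoss()

# log_interval='epoch' (default) keeps the old per-epoch summary;
# log_interval=5 additionally logs every 5th batch with cumulative time/interval.
logging_handler = LoggingHandler(train_metrics=[acc], log_interval=5)

est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc)
est.fit(train_data=dataloader, epochs=1, event_handlers=[logging_handler])
```

Each interval line then looks roughly like `[Epoch 0][Batch 5][Samples 48] time/interval: 0.123s training accuracy: 0.9876` (values illustrative), which is the shape matched by the regex in the new test below.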
72 changes: 68 additions & 4 deletions tests/python/unittest/test_gluon_event_handler.py
@@ -17,13 +17,19 @@

import os
import logging
import sys
import re

import mxnet as mx
from common import TemporaryDirectory
from mxnet import nd
from mxnet.gluon import nn, loss
from mxnet.gluon.contrib.estimator import estimator, event_handler

from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler
try:
from StringIO import StringIO
except ImportError:
from io import StringIO

def _get_test_network(net=nn.Sequential()):
net.add(nn.Dense(128, activation='relu', flatten=False),
@@ -32,9 +38,9 @@ def _get_test_network(net=nn.Sequential()):
return net


def _get_test_data():
data = nd.ones((32, 100))
label = nd.zeros((32, 1))
def _get_test_data(in_size=32):
data = nd.ones((in_size, 100))
label = nd.zeros((in_size, 1))
data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
return mx.gluon.data.DataLoader(data_arr, batch_size=8)

@@ -200,3 +206,61 @@ def epoch_end(self, estimator, *args, **kwargs):
est.fit(test_data, event_handlers=[custom_handler], epochs=10)
assert custom_handler.num_batch == 5 * 4
assert custom_handler.num_epoch == 5

def test_logging_interval():
''' test different options for logging handler '''
''' test case #1: log interval is 1 '''
batch_size = 8
data_size = 100
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
log_interval = 1
net = _get_test_network()
dataloader = _get_test_data(in_size=data_size)
num_epochs = 1
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
est = estimator.Estimator(net=net,
loss=ce_loss,
metrics=acc)

est.fit(train_data=dataloader,
epochs=num_epochs,
event_handlers=[logging])

sys.stdout = old_stdout
log_info_list = mystdout.getvalue().splitlines()
info_len = 0
for info in log_info_list:
match = re.match(
'(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' +
' training accuracy: \d+.\d+)', info)
if match:
info_len += 1

assert(info_len == int(data_size/batch_size/log_interval) + 1)
''' test case #2: log interval is 5 '''
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
acc = mx.metric.Accuracy()
log_interval = 5
logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
est = estimator.Estimator(net=net,
loss=ce_loss,
metrics=acc)
est.fit(train_data=dataloader,
epochs=num_epochs,
event_handlers=[logging])
sys.stdout = old_stdout
log_info_list = mystdout.getvalue().splitlines()
info_len = 0
for info in log_info_list:
match = re.match(
'(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' +
' training accuracy: \d+.\d+)', info)
if match:
info_len += 1

assert(info_len == int(data_size/batch_size/log_interval) + 1)
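
Not part of the PR, just a quick arithmetic check of why the assertions above expect these counts: 100 samples with batch_size 8 yield 13 batches (indices 0 through 12), and an interval line is emitted whenever batch_index % log_interval == 0.

```python
# Sanity check of the expected interval-log counts (assumes the fixtures above:
# data_size=100, batch_size=8, so 13 batches with indices 0..12).
import math

data_size, batch_size = 100, 8
num_batches = math.ceil(data_size / batch_size)          # 13

for log_interval in (1, 5):
    # Batches whose index satisfies index % log_interval == 0.
    expected = len(range(0, num_batches, log_interval))
    assert expected == int(data_size / batch_size / log_interval) + 1
    print(log_interval, expected)                          # 1 -> 13, 5 -> 3
```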