Skip to content

Commit

Permalink
[SSD/PyT] Improved logging
Browse files Browse the repository at this point in the history
  • Loading branch information
shakandrew authored and nv-kkudrynski committed Aug 27, 2021
1 parent 01bbec9 commit d6cd6b8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
11 changes: 8 additions & 3 deletions PyTorch/Detection/SSD/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def make_parser():
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--amp', action='store_true',
help='Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.')
parser.add_argument('--log-interval', type=int, default=20,
help='Logging interval.')
parser.add_argument('--json-summary', type=str, default=None,
help='If provided, the json summary will be written to'
'the specified file.')
Expand Down Expand Up @@ -273,15 +275,18 @@ def log_params(logger, args):

if args.mode == 'benchmark-training':
train_loop_func = benchmark_train_loop
logger = BenchLogger('Training benchmark', json_output=args.json_summary)
logger = BenchLogger('Training benchmark', log_interval=args.log_interval,
json_output=args.json_summary)
args.epochs = 1
elif args.mode == 'benchmark-inference':
train_loop_func = benchmark_inference_loop
logger = BenchLogger('Inference benchmark', json_output=args.json_summary)
logger = BenchLogger('Inference benchmark', log_interval=args.log_interval,
json_output=args.json_summary)
args.epochs = 1
else:
train_loop_func = train_loop
logger = Logger('Training logger', print_freq=1, json_output=args.json_summary)
logger = Logger('Training logger', log_interval=args.log_interval,
json_output=args.json_summary)

log_params(logger, args)

Expand Down
6 changes: 3 additions & 3 deletions PyTorch/Detection/SSD/src/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def update_epoch(self, epoch):


class Logger:
def __init__(self, name, json_output=None, print_freq=20):
def __init__(self, name, json_output=None, log_interval=20):
self.name = name
self.train_loss_logger = IterationAverageMeter("Training loss")
self.train_epoch_time_logger = EpochMeter("Training 1 epoch time")
self.val_acc_logger = EpochMeter("Validation accuracy")
self.print_freq = print_freq
self.log_interval = log_interval

backends = [ DLLogger.StdOutBackend(DLLogger.Verbosity.DEFAULT) ]
if json_output:
Expand Down Expand Up @@ -95,7 +95,7 @@ def log_summary(self):
def update_iter(self, epoch, iteration, loss):
self.train_iter = iteration
self.train_loss_logger.update_iter(loss)
if iteration % self.print_freq == 0:
if iteration % self.log_interval == 0:
self.log('loss', loss)

def update_epoch(self, epoch, acc):
Expand Down

0 comments on commit d6cd6b8

Please sign in to comment.