Commit
modified EvalHook for eval mode to print the correct iter number
yezhen17 committed May 5, 2021
1 parent 53bd40a commit 22e89d7
Showing 3 changed files with 17 additions and 5 deletions.
mmcv/engine/test.py (4 additions, 1 deletion)

@@ -73,7 +73,10 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):

         if rank == 0:
             batch_size = len(result)
-            for _ in range(batch_size * world_size):
+            batch_size_all = batch_size * world_size
+            if batch_size_all + prog_bar.completed > len(dataset):
+                batch_size_all = len(dataset) - prog_bar.completed
+            for _ in range(batch_size_all):
                 prog_bar.update()

     # collect results from all ranks
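The new clamp keeps the shared progress bar from running past the dataset size: with a padded distributed sampler, batch_size * world_size can overshoot on the last round. A minimal stand-alone sketch of the arithmetic, using made-up numbers and a plain integer counter in place of mmcv.ProgressBar:

    # Sketch of the clamping logic above; the numbers are illustrative only.
    # With 4 GPUs, a per-GPU batch size of 32 and a 1000-sample dataset, the
    # padded sampler yields 8 rounds of 4 * 32 = 128 updates, i.e. 1024 > 1000
    # without the clamp.
    dataset_len = 1000
    world_size = 4
    batch_size = 32

    completed = 0  # plays the role of prog_bar.completed
    for _ in range(8):  # ceil(1000 / (4 * 32)) rounds over the padded sampler
        batch_size_all = batch_size * world_size
        if batch_size_all + completed > dataset_len:
            batch_size_all = dataset_len - completed  # clamp the last round
        completed += batch_size_all

    assert completed == dataset_len  # the bar now ends exactly at 100%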
mmcv/runner/hooks/evaluation.py (2 additions, 0 deletions)

@@ -178,6 +178,7 @@ def _do_evaluate(self, runner):

         from mmcv.engine import single_gpu_test
         results = single_gpu_test(runner.model, self.dataloader)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
         key_score = self.evaluate(runner, results)
         if self.save_best:
             self._save_ckpt(runner, key_score)

@@ -368,6 +369,7 @@ def _do_evaluate(self, runner):
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
             key_score = self.evaluate(runner, results)

             if self.save_best:
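Both hooks record len(self.dataloader) under the key 'eval_iter_num' just before calling evaluate(); for a standard validation dataloader that length is the number of validation iterations (batches). A rough sketch with a made-up dataset and batch size, assuming a regular PyTorch DataLoader:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # 4000 dummy samples, batch size 4 -> len(val_loader) == 1000
    val_dataset = TensorDataset(torch.zeros(4000, 3))
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

    eval_iter_num = len(val_loader)  # the value the hooks put into log_buffer.output
    print(eval_iter_num)             # -> 1000, the "[1000]" in the val log line below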
mmcv/runner/hooks/logger/text.py (11 additions, 4 deletions)

@@ -100,11 +100,12 @@ def _log_info(self, log_dict, runner):
                 log_str += f'memory: {log_dict["memory"]}, '
         else:
             # val/test time
-            # by epoch: Epoch[val] [4]
-            # by iter: Iter[val] [100]
+            # here 1000 is the length of the val dataloader
+            # by epoch: Epoch[val] [4][1000]
+            # by iter: Iter[val] [1000]
             if self.by_epoch:
                 log_str = f'Epoch({log_dict["mode"]}) ' \
-                          f'[{log_dict["epoch"]}]\t'
+                          f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
             else:
                 log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

@@ -144,10 +145,16 @@ def _round_float(self, items):
             return items

     def log(self, runner):
+        if 'eval_iter_num' in runner.log_buffer.output:
+            # this doesn't modify runner.iter and is regardless of by_epoch
+            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+        else:
+            cur_iter = self.get_iter(runner, inner_iter=True)
+
         log_dict = OrderedDict(
             mode=self.get_mode(runner),
             epoch=self.get_epoch(runner),
-            iter=self.get_iter(runner, inner_iter=True))
+            iter=cur_iter)

         # only record lr of the first param group
         cur_lr = runner.current_lr()
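For the first text.py hunk, a quick illustration (with made-up values) of the val-time log strings the new f-strings produce; the dataloader length now appears after the epoch:

    log_dict = {'mode': 'val', 'epoch': 4, 'iter': 1000}

    by_epoch_str = f'Epoch({log_dict["mode"]}) [{log_dict["epoch"]}][{log_dict["iter"]}]\t'
    by_iter_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

    print(by_epoch_str)  # Epoch(val) [4][1000]
    print(by_iter_str)   # Iter(val) [1000]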

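And for the second hunk, a stand-alone sketch (not mmcv code) of the branch added to TextLoggerHook.log: an 'eval_iter_num' left by the eval hook is popped and used as the printed iter, while ordinary train logs keep the runner's inner iteration counter. The helper pick_cur_iter and the literal fallback values are hypothetical stand-ins.

    def pick_cur_iter(log_buffer_output, runner_inner_iter):
        # mirrors the new branch: consume 'eval_iter_num' once if present,
        # otherwise fall back to the regular counter
        if 'eval_iter_num' in log_buffer_output:
            return log_buffer_output.pop('eval_iter_num')
        return runner_inner_iter

    buf = {'eval_iter_num': 1000}
    print(pick_cur_iter(buf, 57))  # -> 1000  (val log right after evaluation)
    print(pick_cur_iter(buf, 58))  # -> 58    (next train log; the key was consumed)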