Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix wrong iter number and progress number in the logging during val/test time #914

Merged
merged 4 commits on
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mmcv/engine/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):

if rank == 0:
batch_size = len(result)
for _ in range(batch_size * world_size):
batch_size_all = batch_size * world_size
if batch_size_all + prog_bar.completed > len(dataset):
batch_size_all = len(dataset) - prog_bar.completed
for _ in range(batch_size_all):
prog_bar.update()

# collect results from all ranks
Expand Down
2 changes: 2 additions & 0 deletions mmcv/runner/hooks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def _do_evaluate(self, runner):

from mmcv.engine import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
self._save_ckpt(runner, key_score)
Expand Down Expand Up @@ -368,6 +369,7 @@ def _do_evaluate(self, runner):
gpu_collect=self.gpu_collect)
if runner.rank == 0:
print('\n')
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)

if self.save_best:
Expand Down
12 changes: 11 additions & 1 deletion mmcv/runner/hooks/logger/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def _log_info(self, log_dict, runner):
if torch.cuda.is_available():
log_str += f'memory: {log_dict["memory"]}, '
else:
# val/test time
# here 1000 is the length of the val dataloader
# by epoch: Epoch[val] [4][1000]
# by iter: Iter[val] [1000]
if self.by_epoch:
log_str = f'Epoch({log_dict["mode"]}) ' \
f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
Expand Down Expand Up @@ -141,10 +145,16 @@ def _round_float(self, items):
return items

def log(self, runner):
if 'eval_iter_num' in runner.log_buffer.output:
# this doesn't modify runner.iter and is regardless of by_epoch
cur_iter = runner.log_buffer.output.pop('eval_iter_num')
else:
cur_iter = self.get_iter(runner, inner_iter=True)

log_dict = OrderedDict(
mode=self.get_mode(runner),
epoch=self.get_epoch(runner),
iter=self.get_iter(runner, inner_iter=True))
iter=cur_iter)

# only record lr of the first param group
cur_lr = runner.current_lr()
Expand Down