Skip to content

Commit

Permalink
[Feature] Support save last checkpoint (#853)
Browse files Browse the repository at this point in the history
* [Feature] Support save last checkpoint

* move to before_run, update doc

* move to after train

* add comments
  • Loading branch information
xvjiarui authored Feb 28, 2021
1 parent e5eaf2a commit 34b552b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
45 changes: 32 additions & 13 deletions mmcv/runner/hooks/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class CheckpointHook(Hook):
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be saved
regardless of interval.
sync_buffer (bool): Whether to synchronize buffers in different
gpus. Default: False.
"""
Expand All @@ -34,30 +36,41 @@ def __init__(self,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
save_last=True,
sync_buffer=False,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer

def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir

def after_train_epoch(self, runner):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
if not self.by_epoch:
return

runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 2. reach the last epoch of training
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)

@master_only
def _save_checkpoint(self, runner):
"""Save the current checkpoint and delete unwanted checkpoint."""
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
if runner.meta is not None:
Expand Down Expand Up @@ -91,11 +104,17 @@ def _save_checkpoint(self, runner):
break

def after_train_iter(self, runner):
if self.by_epoch or not self.every_n_iters(runner, self.interval):
if self.by_epoch:
return

runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
6 changes: 6 additions & 0 deletions mmcv/runner/hooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,9 @@ def every_n_iters(self, runner, n):

def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader)

def is_last_epoch(self, runner):
return runner.epoch + 1 == runner._max_epochs

def is_last_iter(self, runner):
return runner.iter + 1 == runner._max_iters

0 comments on commit 34b552b

Please sign in to comment.