Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add auto resume #1172

Merged
merged 7 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mmseg/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger
from mmseg.utils import find_latest_checkpoint, get_root_logger


def init_random_seed(seed=None, device='cuda'):
Expand Down Expand Up @@ -160,6 +160,10 @@ def train_segmentor(model,
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)

if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down
3 changes: 2 additions & 1 deletion mmseg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_root_logger
from .misc import find_latest_checkpoint

__all__ = ['get_root_logger', 'collect_env']
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']
41 changes: 41 additions & 0 deletions mmseg/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings


def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint file under ``path``.

    Used by automatic resume; modified from
    https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py

    Args:
        path (str): The directory to search for checkpoints.
        suffix (str): File extension of the checkpoints. Defaults to pth.

    Returns:
        str | None: File path of the latest checkpoint, or None if the
            directory does not exist or contains no parsable checkpoint.
    """
    if not osp.exists(path):
        warnings.warn("The path of the checkpoints doesn't exist.")
        return None
    # The checkpoint hook keeps `latest.{suffix}` pointing at the most
    # recent checkpoint, so prefer it when present.
    if osp.exists(osp.join(path, f'latest.{suffix}')):
        return osp.join(path, f'latest.{suffix}')

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if len(checkpoints) == 0:
        warnings.warn('There are no checkpoints in the path')
        return None
    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # Checkpoints are saved as 'iter_xx.pth' or 'epoch_xx.pth', where
        # xx is the iteration/epoch number. Skip files that do not follow
        # this pattern (e.g. 'best_mIoU.pth') instead of crashing on them.
        try:
            count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
        except ValueError:
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    return latest_path
40 changes: 40 additions & 0 deletions tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

from mmseg.utils import find_latest_checkpoint


def test_find_latest_checkpoint():
    # An empty work dir has no checkpoint; a missing dir yields None too.
    with tempfile.TemporaryDirectory() as tempdir:
        assert find_latest_checkpoint(tempdir) is None

        missing_dir = osp.join(tempdir, 'none')
        assert find_latest_checkpoint(missing_dir) is None

    # `latest.pth` takes precedence when it exists.
    with tempfile.TemporaryDirectory() as tempdir:
        latest_file = osp.join(tempdir, 'latest.pth')
        with open(latest_file, 'w') as f:
            f.write('latest')
        assert find_latest_checkpoint(tempdir) == latest_file

    # Iteration-style names: the highest iteration number wins.
    with tempfile.TemporaryDirectory() as tempdir:
        for step in range(1600, 160001, 1600):
            with open(osp.join(tempdir, f'iter_{step}.pth'), 'w') as f:
                f.write(f'iter_{step}.pth')
        expected = osp.join(tempdir, 'iter_160000.pth')
        assert find_latest_checkpoint(tempdir) == expected

    # Epoch-style names: the highest epoch number wins.
    with tempfile.TemporaryDirectory() as tempdir:
        for num in range(1, 21):
            with open(osp.join(tempdir, f'epoch_{num}.pth'), 'w') as f:
                f.write(f'epoch_{num}.pth')
        expected = osp.join(tempdir, 'epoch_20.pth')
        assert find_latest_checkpoint(tempdir) == expected
5 changes: 5 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically.')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
Expand Down Expand Up @@ -118,6 +122,7 @@ def main():
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
cfg.auto_resume = args.auto_resume

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
Expand Down