Fix checkpoint bug & Support PBT #53

Merged (10 commits) on Jun 29, 2022
2 changes: 1 addition & 1 deletion .github/workflows/build_unit_test.yaml
@@ -26,7 +26,7 @@ jobs:
run: |
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
- pip install gpytorch nevergrad mmcls mmdet mmsegmentation protobuf==3.20
+ pip install gpytorch==0.3.6 nevergrad mmcls mmdet mmsegmentation protobuf==3.20
pip install -e .
- name: Run unittests and generate coverage report
run: |
2 changes: 1 addition & 1 deletion docs/get_started.md
@@ -34,5 +34,5 @@ python tools/tune.py ${TUNE_CONFIG} [optional tune arguments] [optional task arg

```bash
# MMDetection Example
- python tools/tune.py configs/mmtune/mmdet_asynchb_nevergrad_pso.py configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
+ python tools/tune.py configs/mmtune/mmdet_asynchb_nevergrad_pso.py --trainable_args configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
```
7 changes: 5 additions & 2 deletions mmtune/mm/context/rewriters/resume.py
@@ -1,3 +1,4 @@
+ from os import path as osp
from typing import Dict

from .base import BaseRewriter
@@ -25,6 +26,8 @@ def __call__(self, context: Dict) -> Dict:
Returns:
Dict: The context after rewriting.
"""
- setattr(
- context.get('args'), self.arg_name, context.pop('checkpoint_dir'))
+ if context.get('checkpoint_dir') is not None:
+ setattr(
+ context.get('args'), self.arg_name,
+ osp.join(context.pop('checkpoint_dir'), 'ray_ckpt.pth'))
return context
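For context on the rewriter change above: the old code always overwrote the resume argument, even when Ray had not handed the trial a checkpoint directory, and it pointed at the directory rather than the saved file. The new code only rewrites the argument when a checkpoint exists and targets `ray_ckpt.pth` inside it. A minimal standalone sketch of that behaviour (the helper name `resume_rewrite` and the example paths are illustrative, not part of this PR):

```python
# Illustrative sketch of ResumeFromCkpt.__call__ after this PR (simplified).
from argparse import Namespace
from os import path as osp


def resume_rewrite(context, arg_name='resume_from'):
    # Only rewrite the resume argument when Ray actually provided a checkpoint dir.
    if context.get('checkpoint_dir') is not None:
        setattr(
            context.get('args'), arg_name,
            osp.join(context.pop('checkpoint_dir'), 'ray_ckpt.pth'))
    return context


# Fresh trial: no checkpoint handed over, args stay untouched.
ctx = {'args': Namespace(resume_from=None), 'checkpoint_dir': None}
assert resume_rewrite(ctx)['args'].resume_from is None

# Restored or exploited trial: resume_from now points at the saved file.
ctx = {'args': Namespace(resume_from=None), 'checkpoint_dir': '/tmp/ckpt_000002'}
assert resume_rewrite(ctx)['args'].resume_from == '/tmp/ckpt_000002/ray_ckpt.pth'
```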
29 changes: 10 additions & 19 deletions mmtune/mm/hooks/checkpoint.py
@@ -45,10 +45,6 @@ def __init__(self,
saved regardless of interval. Default: True.
sync_buffer (bool, optional): Whether to synchronize buffers in
different gpus. Default: False.
- file_client_args (dict, optional): Arguments to instantiate a
- FileClient. See :class:`mmcv.fileio.FileClient` for details.
- Default: None.
- `New in version 1.3.16.`
"""
self.interval = interval
self.by_epoch = by_epoch
@@ -57,25 +53,18 @@ def __init__(self,
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer
- self.file_client_args = file_client_args

"""Save checkpoints periodically."""

- def get_iter(self, runner: BaseRunner, inner_iter: bool = False):
- """Get the current iteration.
+ def before_run(self, runner: BaseRunner):
+ """This hook omits the setting process because it gets information from
+ the ray session.

Args:
runner (:obj:`mmcv.runner.BaseRunner`):
- The runner to get the current iteration.
- inner_iter (bool):
- Whether to get the inner iteration.
+ The runner.
"""

- if self.by_epoch and inner_iter:
- current_iter = runner.inner_iter + 1
- else:
- current_iter = runner.iter + 1
- return current_iter
+ pass

@master_only
def _save_checkpoint(self, runner: BaseRunner) -> None:
@@ -92,7 +81,7 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
mmcv_version=mmcv.__version__,
time=time.asctime(),
epoch=runner.epoch + 1,
- iter=runner.iter)
+ iter=runner.iter + 1)
if is_module_wrapper(model):
model = model.module
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
@@ -111,6 +100,8 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
checkpoint['optimizer'][name] = optim.state_dict()

with distributed_checkpoint_dir(
- step=self.get_iter(runner)) as checkpoint_dir:
- path = os.path.join(checkpoint_dir, 'ray_checkpoint.pth')
+ step=(runner.epoch + 1) //
+ self.interval if self.by_epoch else (runner.iter + 1) //
+ self.interval) as checkpoint_dir:
+ path = os.path.join(checkpoint_dir, 'ray_ckpt.pth')
torch.save(checkpoint, path)
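The hook now derives Ray's checkpoint step directly from runner state instead of the removed `get_iter` helper: with `by_epoch=True` the step is `(epoch + 1) // interval`, otherwise `(iter + 1) // interval`. A small standalone sketch of how those indices come out (the `interval` values below are made up for illustration):

```python
# Mirror of the expression passed to distributed_checkpoint_dir in the hunk above.
def ray_step(epoch, iteration, by_epoch=True, interval=2):
    return (epoch + 1) // interval if by_epoch else (iteration + 1) // interval


# by_epoch=True, interval=2: checkpoints after epochs 2, 4, ... map to steps 1, 2, ...
assert ray_step(epoch=1, iteration=0) == 1      # runner.epoch == 1 means 2 epochs done
assert ray_step(epoch=3, iteration=0) == 2
# by_epoch=False, interval=100: iteration index 99 means 100 iters done -> step 1
assert ray_step(epoch=0, iteration=99, by_epoch=False, interval=100) == 1
```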
3 changes: 2 additions & 1 deletion mmtune/ray/schedulers/__init__.py
@@ -1,3 +1,4 @@
from .builder import SCHEDULERS, build_scheduler
+ from .pbt import PopulationBasedTraining

- __all__ = ['SCHEDULERS', 'build_scheduler']
+ __all__ = ['SCHEDULERS', 'build_scheduler', 'PopulationBasedTraining']
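A quick check of what the re-export enables, assuming `SCHEDULERS` is an mmcv-style registry (its implementation lives in `builder.py` and is not shown in this diff):

```python
# Illustrative only: the subclass registers itself under the same name as
# Ray's PopulationBasedTraining, which is why force=True is needed in pbt.py.
from mmtune.ray.schedulers import SCHEDULERS, PopulationBasedTraining

assert SCHEDULERS.get('PopulationBasedTraining') is PopulationBasedTraining
```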
67 changes: 67 additions & 0 deletions mmtune/ray/schedulers/pbt.py
@@ -0,0 +1,67 @@
import copy
import random
from typing import Callable, Dict, Optional

from ray.tune.sample import Domain
from ray.tune.schedulers.pbt import \
PopulationBasedTraining as _PopulationBasedTraining

from mmtune.ray.schedulers import SCHEDULERS
from mmtune.ray.spaces import build_space
from mmtune.utils import ImmutableContainer


def explore(
config: Dict,
mutations: Dict,
resample_probability: float,
custom_explore_fn: Optional[Callable],
) -> Dict:
"""Return a config perturbed as specified.

Args:
config: Original hyperparameter configuration.
mutations: Specification of mutations to perform as documented
in the PopulationBasedTraining scheduler.
resample_probability: Probability of allowing resampling of a
particular variable.
custom_explore_fn: Custom explore fn applied after built-in
config perturbations are.
"""
new_config = copy.deepcopy(config)
for key, distribution in mutations.items():
assert isinstance(distribution, Domain)
if random.random() < resample_probability:
new_config[key] = ImmutableContainer.decouple(
distribution.sample(None))

try:
new_config[key] = config[key] * 1.2 if random.random(
) > 0.5 else config[key] * 0.8
if isinstance(config[key], int):
new_config[key] = int(new_config[key])
except Exception:
new_config[key] = config[key]
if custom_explore_fn:
new_config = custom_explore_fn(new_config)
assert new_config is not None
return new_config


@SCHEDULERS.register_module(force=True)
class PopulationBasedTraining(_PopulationBasedTraining):

def __init__(self, *args, **kwargs) -> None:
hyperparam_mutations = kwargs.get('hyperparam_mutations',
dict()).copy()
kwargs.update(hyperparam_mutations=build_space(hyperparam_mutations))
super().__init__(*args, **kwargs)

def _get_new_config(self, trial, trial_to_clone):
"""Gets new config for trial by exploring trial_to_clone's config."""
return explore(
trial_to_clone.config,
self._hyperparam_mutations,
self._resample_probability,
self._custom_explore_fn,
)
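The custom `explore` above follows the usual PBT recipe: each mutated hyperparameter is either resampled from its search space (with probability `resample_probability`) or nudged by a factor of 1.2 or 0.8, with integer values cast back to int. A loose standalone sketch of that rule, without Ray's `Domain` objects or `ImmutableContainer` (the `perturb` function below is illustrative, not the PR's code):

```python
import random


def perturb(value, resample, resample_probability=0.25):
    # Resample from the search space with some probability...
    if random.random() < resample_probability:
        return resample()
    # ...otherwise scale the current value up or down by 20%.
    new_value = value * 1.2 if random.random() > 0.5 else value * 0.8
    return int(new_value) if isinstance(value, int) else new_value


random.seed(0)
print(perturb(0.02, resample=lambda: random.uniform(1e-4, 1e-1)))  # e.g. 0.024
print(perturb(8, resample=lambda: random.randint(2, 32)))          # ints stay ints
```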
2 changes: 1 addition & 1 deletion mmtune/run.py
@@ -54,7 +54,7 @@ def parse_args() -> Namespace:
help='number of gpus each worker uses.',
)
parser.add_argument(
- 'trainable_args',
+ '--trainable-args',
nargs=REMAINDER,
type=str,
help='Rest from the trainable process.',
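Moving `trainable_args` from a positional to an optional `--trainable-args` flag makes it explicit where the tuner's own arguments end and the task's begin; everything after the flag is swallowed by `REMAINDER` and forwarded untouched. A standalone sketch of that parsing behaviour (the `--num-gpus-per-worker` flag here is only a stand-in for the other options in `parse_args`):

```python
from argparse import REMAINDER, ArgumentParser

parser = ArgumentParser()
parser.add_argument('--num-gpus-per-worker', type=int, default=1)
parser.add_argument('--trainable-args', nargs=REMAINDER, type=str, default=[])

args = parser.parse_args([
    '--num-gpus-per-worker', '2',
    '--trainable-args', 'configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
    '--seed', '42',
])
# Everything after --trainable-args, including further flags, is passed through.
assert args.trainable_args == [
    'configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', '--seed', '42'
]
```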
7 changes: 1 addition & 6 deletions tests/test_mm/test_hooks.py
@@ -23,16 +23,11 @@ def test_raycheckpointhook():
mock_runner.iter = 5
mock_runner.epoch = 5

- cur_iter = hook.get_iter(mock_runner, False)
- assert cur_iter == 6
- cur_iter = hook.get_iter(mock_runner, True)
- assert cur_iter == 4

mock_runner.model = torch.nn.Linear(2, 2)
mock_runner.optimizer = torch.optim.Adam(mock_runner.model.parameters())

hook._save_checkpoint(mock_runner)
- assert os.path.exists('ray_checkpoint.pth')
+ assert os.path.exists('ray_ckpt.pth')


@patch.object(RayTuneLoggerHook, 'get_loggable_tags')
2 changes: 1 addition & 1 deletion tests/test_mm/test_rewriters.py
@@ -128,4 +128,4 @@ def test_resume_ckpt():

resume_from_ckpt = ResumeFromCkpt()
context = resume_from_ckpt(context)
- assert context.get('args').resume_from == 'test'
+ assert context.get('args').resume_from == 'test/ray_ckpt.pth'