From e874e04609968e28312e0ab36249dcaa5059cda0 Mon Sep 17 00:00:00 2001
From: yhna
Date: Sun, 26 Jun 2022 18:23:05 +0900
Subject: [PATCH 1/9] Fix checkpointing

---
 mmtune/mm/context/rewriters/resume.py |  7 +++++--
 mmtune/mm/hooks/checkpoint.py         | 29 +++++++++------------------
 2 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/mmtune/mm/context/rewriters/resume.py b/mmtune/mm/context/rewriters/resume.py
index 7f885430..a41176d7 100644
--- a/mmtune/mm/context/rewriters/resume.py
+++ b/mmtune/mm/context/rewriters/resume.py
@@ -1,3 +1,4 @@
+from os import path as osp
 from typing import Dict
 
 from .base import BaseRewriter
@@ -25,6 +26,8 @@ def __call__(self, context: Dict) -> Dict:
         Returns:
             Dict: The context after rewriting.
         """
-        setattr(
-            context.get('args'), self.arg_name, context.pop('checkpoint_dir'))
+        if context.get('checkpoint_dir') is not None:
+            setattr(
+                context.get('args'), self.arg_name,
+                osp.join(context.pop('checkpoint_dir'), 'ray_ckpt.pth'))
         return context
diff --git a/mmtune/mm/hooks/checkpoint.py b/mmtune/mm/hooks/checkpoint.py
index 0d8c161b..fc20da80 100644
--- a/mmtune/mm/hooks/checkpoint.py
+++ b/mmtune/mm/hooks/checkpoint.py
@@ -45,10 +45,6 @@ def __init__(self,
                 saved regardless of interval. Default: True.
             sync_buffer (bool, optional): Whether to synchronize buffers in
                 different gpus. Default: False.
-            file_client_args (dict, optional): Arguments to instantiate a
-                FileClient. See :class:`mmcv.fileio.FileClient` for details.
-                Default: None.
-                `New in version 1.3.16.`
         """
         self.interval = interval
         self.by_epoch = by_epoch
@@ -57,25 +53,18 @@ def __init__(self,
         self.save_last = save_last
         self.args = kwargs
         self.sync_buffer = sync_buffer
-        self.file_client_args = file_client_args
 
     """Save checkpoints periodically."""
 
-    def get_iter(self, runner: BaseRunner, inner_iter: bool = False):
-        """Get the current iteration.
+    def before_run(self, runner: BaseRunner):
+        """This hook omits the setting process because it gets information from
+        the ray session.
 
         Args:
             runner (:obj:`mmcv.runner.BaseRunner`):
-                The runner to get the current iteration.
-            inner_iter (bool):
-                Whether to get the inner iteration.
+                The runner.
         """
-
-        if self.by_epoch and inner_iter:
-            current_iter = runner.inner_iter + 1
-        else:
-            current_iter = runner.iter + 1
-        return current_iter
+        pass
 
     @master_only
     def _save_checkpoint(self, runner: BaseRunner) -> None:
@@ -92,7 +81,7 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
             mmcv_version=mmcv.__version__,
             time=time.asctime(),
             epoch=runner.epoch + 1,
-            iter=runner.iter)
+            iter=runner.iter + 1)
         if is_module_wrapper(model):
             model = model.module
         if hasattr(model, 'CLASSES') and model.CLASSES is not None:
@@ -111,6 +100,8 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
                 checkpoint['optimizer'][name] = optim.state_dict()
 
         with distributed_checkpoint_dir(
-                step=self.get_iter(runner)) as checkpoint_dir:
-            path = os.path.join(checkpoint_dir, 'ray_checkpoint.pth')
+                step=(runner.epoch + 1) //
+                self.interval if self.by_epoch else (runner.iter + 1) //
+                self.interval) as checkpoint_dir:
+            path = os.path.join(checkpoint_dir, 'ray_ckpt.pth')
             torch.save(checkpoint, path)

From 3bbaada1d5f62b2bf0348e7ad4f09363e3cf4b49 Mon Sep 17 00:00:00 2001
From: yhna
Date: Sun, 26 Jun 2022 21:55:50 +0900
Subject: [PATCH 2/9] Fix argparse

---
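Note: the task config is now passed through an explicit `--trainable_args` flag instead of a positional argument, so `REMAINDER` only starts collecting arguments after the flag. A minimal sketch of the resulting parsing behaviour (simplified stand-in parser, not the full tools/tune.py argument list):

```python
from argparse import REMAINDER, ArgumentParser

# Simplified stand-in for tools/tune.py's parser: REMAINDER swallows
# everything after --trainable_args and forwards it to the task process.
parser = ArgumentParser()
parser.add_argument('tune_config')
parser.add_argument('--trainable_args', nargs=REMAINDER, type=str)

args = parser.parse_args([
    'configs/mmtune/mmdet_asynchb_nevergrad_pso.py',
    '--trainable_args',
    'configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
])
assert args.trainable_args == [
    'configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
]
```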
""" - - if self.by_epoch and inner_iter: - current_iter = runner.inner_iter + 1 - else: - current_iter = runner.iter + 1 - return current_iter + pass @master_only def _save_checkpoint(self, runner: BaseRunner) -> None: @@ -92,7 +81,7 @@ def _save_checkpoint(self, runner: BaseRunner) -> None: mmcv_version=mmcv.__version__, time=time.asctime(), epoch=runner.epoch + 1, - iter=runner.iter) + iter=runner.iter + 1) if is_module_wrapper(model): model = model.module if hasattr(model, 'CLASSES') and model.CLASSES is not None: @@ -111,6 +100,8 @@ def _save_checkpoint(self, runner: BaseRunner) -> None: checkpoint['optimizer'][name] = optim.state_dict() with distributed_checkpoint_dir( - step=self.get_iter(runner)) as checkpoint_dir: - path = os.path.join(checkpoint_dir, 'ray_checkpoint.pth') + step=(runner.epoch + 1) // + self.interval if self.by_epoch else (runner.iter + 1) // + self.interval) as checkpoint_dir: + path = os.path.join(checkpoint_dir, 'ray_ckpt.pth') torch.save(checkpoint, path) From 3bbaada1d5f62b2bf0348e7ad4f09363e3cf4b49 Mon Sep 17 00:00:00 2001 From: yhna Date: Sun, 26 Jun 2022 21:55:50 +0900 Subject: [PATCH 2/9] Fix argparse --- docs/get_started.md | 2 +- tools/tune.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/get_started.md b/docs/get_started.md index c779a2f2..7af10eb0 100644 --- a/docs/get_started.md +++ b/docs/get_started.md @@ -34,5 +34,5 @@ python tools/tune.py ${TUNE_CONFIG} [optional tune arguments] [optional task arg ```bash # MMDetection Example -python tools/tune.py configs/mmtune/mmdet_asynchb_nevergrad_pso.py configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py +python tools/tune.py configs/mmtune/mmdet_asynchb_nevergrad_pso.py --trainable_args configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py ``` diff --git a/tools/tune.py b/tools/tune.py index 3d900ea0..04798b08 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -54,7 +54,7 @@ def parse_args() -> Namespace: help='number of gpus each worker uses.', ) parser.add_argument( - 'trainable_args', + '--trainable_args', nargs=REMAINDER, type=str, help='Rest from the trainable process.', From dd618d289f3ddedd3b59e956fcbe53f67602c60f Mon Sep 17 00:00:00 2001 From: Younghwan Na <100389977+yhna940@users.noreply.github.com> Date: Mon, 27 Jun 2022 16:26:12 +0900 Subject: [PATCH 3/9] Update tune.py Signed-off-by: Younghwan Na <100389977+yhna940@users.noreply.github.com> --- tools/tune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/tune.py b/tools/tune.py index 04798b08..5aa62b63 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -54,7 +54,7 @@ def parse_args() -> Namespace: help='number of gpus each worker uses.', ) parser.add_argument( - '--trainable_args', + '--trainable-args', nargs=REMAINDER, type=str, help='Rest from the trainable process.', From 044336035a2a03f3aa66bbaec59cf7ae9e9a47df Mon Sep 17 00:00:00 2001 From: yhna Date: Mon, 27 Jun 2022 21:39:43 +0900 Subject: [PATCH 4/9] Fix pbt --- mmtune/ray/schedulers/__init__.py | 3 ++- mmtune/ray/schedulers/pbt.py | 14 ++++++++++++++ setup.cfg | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 mmtune/ray/schedulers/pbt.py diff --git a/mmtune/ray/schedulers/__init__.py b/mmtune/ray/schedulers/__init__.py index e90342f0..07c26c5d 100644 --- a/mmtune/ray/schedulers/__init__.py +++ b/mmtune/ray/schedulers/__init__.py @@ -1,3 +1,4 @@ from .builder import SCHEDULERS, build_scheduler +from .pbt import PopulationBasedTraining -__all__ = ['SCHEDULERS', 'build_scheduler'] +__all__ = 
 tools/tune.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/tune.py b/tools/tune.py
index 04798b08..5aa62b63 100644
--- a/tools/tune.py
+++ b/tools/tune.py
@@ -54,7 +54,7 @@ def parse_args() -> Namespace:
         help='number of gpus each worker uses.',
     )
     parser.add_argument(
-        '--trainable_args',
+        '--trainable-args',
         nargs=REMAINDER,
         type=str,
         help='Rest from the trainable process.',

From 044336035a2a03f3aa66bbaec59cf7ae9e9a47df Mon Sep 17 00:00:00 2001
From: yhna
Date: Mon, 27 Jun 2022 21:39:43 +0900
Subject: [PATCH 4/9] Fix pbt

---
 mmtune/ray/schedulers/__init__.py |  3 ++-
 mmtune/ray/schedulers/pbt.py      | 14 ++++++++++++++
 setup.cfg                         |  2 +-
 3 files changed, 17 insertions(+), 2 deletions(-)
 create mode 100644 mmtune/ray/schedulers/pbt.py

diff --git a/mmtune/ray/schedulers/__init__.py b/mmtune/ray/schedulers/__init__.py
index e90342f0..07c26c5d 100644
--- a/mmtune/ray/schedulers/__init__.py
+++ b/mmtune/ray/schedulers/__init__.py
@@ -1,3 +1,4 @@
 from .builder import SCHEDULERS, build_scheduler
+from .pbt import PopulationBasedTraining
 
-__all__ = ['SCHEDULERS', 'build_scheduler']
+__all__ = ['SCHEDULERS', 'build_scheduler', 'PopulationBasedTraining']
diff --git a/mmtune/ray/schedulers/pbt.py b/mmtune/ray/schedulers/pbt.py
new file mode 100644
index 00000000..d5252c03
--- /dev/null
+++ b/mmtune/ray/schedulers/pbt.py
@@ -0,0 +1,14 @@
+from mmtun.ray.scheduler import SCHEDULERS
+from ray.tune.schedulers.pbt import \
+    PopulationBasedTraining as _PopulationBasedTraining
+
+from mmtune.ray.spaces import build_space
+
+
+@SCHEDULERS.register_module(force=True)
+class PopulationBasedTraining(_PopulationBasedTraining):
+
+    def __init__(self, *args, **kwargs) -> None:
+        hyperparam_mutations = kwargs.get('hyperparam_mutations',
+                                          dict()).copy()
+        kwargs.update(hyperparam_mutations=build_space(hyperparam_mutations))
diff --git a/setup.cfg b/setup.cfg
index c112c949..765b0f97 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -13,6 +13,6 @@ line_length = 79
 multi_line_output = 0
 extra_standard_library = setuptools
 known_first_party = mmtune
-known_third_party = gpytorch,mmcv,numpy,pytest,ray,setuptools,torch
+known_third_party = gpytorch,mmcv,mmtun,numpy,pytest,ray,setuptools,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

From 2aa67bad3f2e54589118566f8ccccabab2db7b71 Mon Sep 17 00:00:00 2001
From: Younghwan Na <100389977+yhna940@users.noreply.github.com>
Date: Tue, 28 Jun 2022 10:15:45 +0900
Subject: [PATCH 5/9] Fix pbt

---
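Note: `explore` ports Ray Tune's built-in PBT perturbation logic so that mmtune's `ImmutableContainer` values can be unwrapped when a hyperparameter is resampled. A toy sketch of the numeric perturbation rule it applies (illustrative helper, not the mmtune API):

```python
import random

def perturb(value):
    # Scale by 1.2 or 0.8 with equal probability, as in explore(); integer
    # hyperparameters are cast back to int after the perturbation.
    new = value * 1.2 if random.random() > 0.5 else value * 0.8
    return int(new) if isinstance(value, int) else new

assert perturb(0.1) in (0.1 * 1.2, 0.1 * 0.8)
assert isinstance(perturb(16), int)
```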
+ """ + new_config = copy.deepcopy(config) + for key, distribution in mutations.items(): + assert isinstance(distribution, Domain) + if random.random() < resample_probability: + new_config[key] = ImmutableContainer.decouple( + distribution.sample(None)) + + try: + new_config[key] = config[key] * 1.2 if random.random( + ) > 0.5 else config[key] * 0.8 + if isinstance(config[key], int): + new_config[key] = int(new_config[key]) + except Exception: + new_config[key] = config[key] + if custom_explore_fn: + new_config = custom_explore_fn(new_config) + assert new_config is not None + return new_config @SCHEDULERS.register_module(force=True) @@ -12,3 +55,12 @@ def __init__(self, *args, **kwargs) -> None: hyperparam_mutations = kwargs.get('hyperparam_mutations', dict()).copy() kwargs.update(hyperparam_mutations=build_space(hyperparam_mutations)) + + def _get_new_config(self, trial, trial_to_clone): + """Gets new config for trial by exploring trial_to_clone's config.""" + return explore( + trial_to_clone.config, + self._hyperparam_mutations, + self._resample_probability, + self._custom_explore_fn, + ) From ee04d500089c53d91d6dca1392bb026843bea44d Mon Sep 17 00:00:00 2001 From: Younghwan Na <100389977+yhna940@users.noreply.github.com> Date: Tue, 28 Jun 2022 10:36:28 +0900 Subject: [PATCH 6/9] Fix minor bug --- mmtune/ray/schedulers/pbt.py | 3 ++- setup.cfg | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mmtune/ray/schedulers/pbt.py b/mmtune/ray/schedulers/pbt.py index 9dfc3db8..d3484a44 100644 --- a/mmtune/ray/schedulers/pbt.py +++ b/mmtune/ray/schedulers/pbt.py @@ -2,11 +2,11 @@ import random from typing import Callable, Dict, Optional -from mmtun.ray.scheduler import SCHEDULERS from ray.tune.sample import Domain from ray.tune.schedulers.pbt import \ PopulationBasedTraining as _PopulationBasedTraining +from mmtune.ray.schedulers import SCHEDULERS from mmtune.ray.spaces import build_space from mmtune.utils import ImmutableContainer @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs) -> None: hyperparam_mutations = kwargs.get('hyperparam_mutations', dict()).copy() kwargs.update(hyperparam_mutations=build_space(hyperparam_mutations)) + super().__init__(*args, **kwargs) def _get_new_config(self, trial, trial_to_clone): """Gets new config for trial by exploring trial_to_clone's config.""" diff --git a/setup.cfg b/setup.cfg index 765b0f97..c112c949 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,6 @@ line_length = 79 multi_line_output = 0 extra_standard_library = setuptools known_first_party = mmtune -known_third_party = gpytorch,mmcv,mmtun,numpy,pytest,ray,setuptools,torch +known_third_party = gpytorch,mmcv,numpy,pytest,ray,setuptools,torch no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY From 7ae37c176c2d13fd82415cb176a725187dd94b20 Mon Sep 17 00:00:00 2001 From: Younghwan Na <100389977+yhna940@users.noreply.github.com> Date: Wed, 29 Jun 2022 11:42:45 +0900 Subject: [PATCH 7/9] Fix test code --- tests/test_mm/test_hooks.py | 5 ----- tests/test_mm/test_rewriters.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_mm/test_hooks.py b/tests/test_mm/test_hooks.py index d929c314..9e309357 100644 --- a/tests/test_mm/test_hooks.py +++ b/tests/test_mm/test_hooks.py @@ -23,11 +23,6 @@ def test_raycheckpointhook(): mock_runner.iter = 5 mock_runner.epoch = 5 - cur_iter = hook.get_iter(mock_runner, False) - assert cur_iter == 6 - cur_iter = hook.get_iter(mock_runner, True) - assert cur_iter == 4 - mock_runner.model = torch.nn.Linear(2, 2) 
 mmtune/ray/schedulers/pbt.py | 3 ++-
 setup.cfg                    | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mmtune/ray/schedulers/pbt.py b/mmtune/ray/schedulers/pbt.py
index 9dfc3db8..d3484a44 100644
--- a/mmtune/ray/schedulers/pbt.py
+++ b/mmtune/ray/schedulers/pbt.py
@@ -2,11 +2,11 @@
 import random
 from typing import Callable, Dict, Optional
 
-from mmtun.ray.scheduler import SCHEDULERS
 from ray.tune.sample import Domain
 from ray.tune.schedulers.pbt import \
     PopulationBasedTraining as _PopulationBasedTraining
 
+from mmtune.ray.schedulers import SCHEDULERS
 from mmtune.ray.spaces import build_space
 from mmtune.utils import ImmutableContainer
 
@@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs) -> None:
         hyperparam_mutations = kwargs.get('hyperparam_mutations',
                                           dict()).copy()
         kwargs.update(hyperparam_mutations=build_space(hyperparam_mutations))
+        super().__init__(*args, **kwargs)
 
     def _get_new_config(self, trial, trial_to_clone):
         """Gets new config for trial by exploring trial_to_clone's config."""
diff --git a/setup.cfg b/setup.cfg
index 765b0f97..c112c949 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -13,6 +13,6 @@ line_length = 79
 multi_line_output = 0
 extra_standard_library = setuptools
 known_first_party = mmtune
-known_third_party = gpytorch,mmcv,mmtun,numpy,pytest,ray,setuptools,torch
+known_third_party = gpytorch,mmcv,numpy,pytest,ray,setuptools,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

From 7ae37c176c2d13fd82415cb176a725187dd94b20 Mon Sep 17 00:00:00 2001
From: Younghwan Na <100389977+yhna940@users.noreply.github.com>
Date: Wed, 29 Jun 2022 11:42:45 +0900
Subject: [PATCH 7/9] Fix test code

---
 tests/test_mm/test_hooks.py     | 5 -----
 tests/test_mm/test_rewriters.py | 2 +-
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/tests/test_mm/test_hooks.py b/tests/test_mm/test_hooks.py
index d929c314..9e309357 100644
--- a/tests/test_mm/test_hooks.py
+++ b/tests/test_mm/test_hooks.py
@@ -23,11 +23,6 @@ def test_raycheckpointhook():
     mock_runner.iter = 5
     mock_runner.epoch = 5
 
-    cur_iter = hook.get_iter(mock_runner, False)
-    assert cur_iter == 6
-    cur_iter = hook.get_iter(mock_runner, True)
-    assert cur_iter == 4
-
     mock_runner.model = torch.nn.Linear(2, 2)
     mock_runner.optimizer = torch.optim.Adam(mock_runner.model.parameters())
 
diff --git a/tests/test_mm/test_rewriters.py b/tests/test_mm/test_rewriters.py
index 5c1790a2..b1a11e64 100644
--- a/tests/test_mm/test_rewriters.py
+++ b/tests/test_mm/test_rewriters.py
@@ -128,4 +128,4 @@ def test_resume_ckpt():
     resume_from_ckpt = ResumeFromCkpt()
     context = resume_from_ckpt(context)
 
-    assert context.get('args').resume_from == 'test'
+    assert context.get('args').resume_from == 'test/ray_ckpt.pth'

From 058ea316b8375e290e64b3b6cbb41bd59cbe4b7b Mon Sep 17 00:00:00 2001
From: Younghwan Na <100389977+yhna940@users.noreply.github.com>
Date: Wed, 29 Jun 2022 13:22:47 +0900
Subject: [PATCH 8/9] Fix minor test code

---
 tests/test_mm/test_hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_mm/test_hooks.py b/tests/test_mm/test_hooks.py
index 9e309357..c0f5194c 100644
--- a/tests/test_mm/test_hooks.py
+++ b/tests/test_mm/test_hooks.py
@@ -27,7 +27,7 @@ def test_raycheckpointhook():
     mock_runner.optimizer = torch.optim.Adam(mock_runner.model.parameters())
 
     hook._save_checkpoint(mock_runner)
-    assert os.path.exists('ray_checkpoint.pth')
+    assert os.path.exists('ray_ckpt.pth')
 
 
 @patch.object(RayTuneLoggerHook, 'get_loggable_tags')

From 8439094e3821f11e1d371b447593cff51ad40352 Mon Sep 17 00:00:00 2001
From: Younghwan Na <100389977+yhna940@users.noreply.github.com>
Date: Wed, 29 Jun 2022 14:31:14 +0900
Subject: [PATCH 9/9] Fix gpytorch in test workflow

---
 .github/workflows/build_unit_test.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build_unit_test.yaml b/.github/workflows/build_unit_test.yaml
index 7c3d50b7..03ebaf29 100644
--- a/.github/workflows/build_unit_test.yaml
+++ b/.github/workflows/build_unit_test.yaml
@@ -26,7 +26,7 @@ jobs:
         run: |
           pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
           pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
-          pip install gpytorch nevergrad mmcls mmdet mmsegmentation protobuf==3.20
+          pip install gpytorch==0.3.6 nevergrad mmcls mmdet mmsegmentation protobuf==3.20
           pip install -e .
       - name: Run unittests and generate coverage report
         run: |