diff --git a/configs/mmtune/mmseg_asynchb_nevergrad_pso.py b/configs/mmtune/mmseg_asynchb_nevergrad_pso.py
index 335d1307..a9990181 100644
--- a/configs/mmtune/mmseg_asynchb_nevergrad_pso.py
+++ b/configs/mmtune/mmseg_asynchb_nevergrad_pso.py
@@ -7,7 +7,7 @@
 space = {
     'model': {{_base_.model}},
     'optimizer': {{_base_.optimizer}},
-    'data.samples_per_gpus': {{_base_.batch_size}}
+    'data.samples_per_gpu': {{_base_.batch_size}}
 }
 
 metric = 'val/mIoU'
diff --git a/mmtune/mm/context/rewriters/register.py b/mmtune/mm/context/rewriters/register.py
index 7fafe53b..4c158131 100644
--- a/mmtune/mm/context/rewriters/register.py
+++ b/mmtune/mm/context/rewriters/register.py
@@ -10,7 +10,8 @@ def __init__(self, post_custom_hooks: List[str]) -> None:
         self.post_custom_hooks = post_custom_hooks
 
     def __call__(self, context: dict) -> dict:
-        custom_hooks = getattr(context['cfg'], 'custom_hooks', [])
+        custom_hooks = getattr(context['cfg'], 'custom_hooks', []).copy()
         for custom_hook in self.post_custom_hooks:
             custom_hooks.append(custom_hook)
+        context['cfg'].custom_hooks = custom_hooks
         return context
diff --git a/mmtune/mm/tasks/mmtrainbase.py b/mmtune/mm/tasks/mmtrainbase.py
index 5fcc76ec..f969010a 100644
--- a/mmtune/mm/tasks/mmtrainbase.py
+++ b/mmtune/mm/tasks/mmtrainbase.py
@@ -35,6 +35,7 @@ def train_model(cls, model: torch.nn.Module,
 
     @classmethod
     def contextaware_run(cls, status, backend, *args, **kwargs) -> None:
+        from mmtune.mm.tasks import hooks  # noqa F401
         if backend == 'nccl' and os.getenv('NCCL_BLOCKING_WAIT') is None:
             os.environ['NCCL_BLOCKING_WAIT'] = '0'
         context_manager = ContextManager(**status)
@@ -53,5 +54,5 @@ def create_trainable(cls, backend: str = 'nccl') -> ray.tune.trainable:
                 rewriters=cls.REWRITERS), backend),
             backend=backend,
             num_workers=cls.ARGS.num_workers,
-            num_gpus_per_worker=cls.ARGS.num_cpus_per_worker,
+            num_gpus_per_worker=cls.ARGS.num_gpus_per_worker,
             num_cpus_per_worker=cls.ARGS.num_cpus_per_worker)