diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py
index a713da9a70..b555df9e94 100644
--- a/mmengine/_strategy/base.py
+++ b/mmengine/_strategy/base.py
@@ -322,7 +322,8 @@ def compile_model(
         Returns:
             nn.Module: Compiled model.
         """
-        if isinstance(compile, bool) and not compile:
+        if isinstance(compile, bool) and not compile or \
+                isinstance(compile, dict) and not compile.get('disable', False):
             return model
 
         assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py
index 4113aebf9e..b88bc7c2b0 100644
--- a/mmengine/model/wrappers/distributed.py
+++ b/mmengine/model/wrappers/distributed.py
@@ -95,6 +95,7 @@ def __init__(self,
 
     def train_step(self, data: Union[dict, tuple, list],
                    optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
+        return self.module.train_step(data, optim_wrapper)
         """Interface for model forward, backward and parameters updating during
         training process.
 
@@ -126,6 +127,7 @@ def train_step(self, data: Union[dict, tuple, list],
         return log_vars
 
     def val_step(self, data: Union[dict, tuple, list]) -> list:
+        return self.module.val_step(data)
         """Gets the prediction of module during validation process.
 
         Args:
@@ -137,6 +139,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list:
         return self.module.val_step(data)
 
     def test_step(self, data: Union[dict, tuple, list]) -> list:
+        return self.module.test_step(data)
         """Gets the predictions of module during testing process.
 
         Args:
diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index 4f3323f2cc..60200924b5 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from functools import partial
 from typing import Union
 
 import torch
@@ -17,7 +18,8 @@
 elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler as amp_GradScaler
+    GradScaler = partial(amp_GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 8557f4d34c..b57ebc315a 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -207,5 +207,6 @@ def build_optim_wrapper(model: nn.Module,
             type=constructor_type,
             optim_wrapper_cfg=optim_wrapper_cfg,
             paramwise_cfg=paramwise_cfg))
+    optim_wrapper = optim_wrapper_constructor(model)
 
     return optim_wrapper
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 5a678db7b9..7be8995781 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -12,6 +12,7 @@
 from mmengine.registry import LOOPS
 from mmengine.structures import BaseDataElement
 from mmengine.utils import is_list_of
+from mmengine.dataset.sampler import InfiniteSampler
 from .amp import autocast
 from .base_loop import BaseLoop
 from .utils import calc_dynamic_intervals
@@ -274,14 +275,28 @@ def run(self) -> None:
         # In iteration-based training loop, we treat the whole training process
         # as a big epoch and execute the corresponding hook.
         self.runner.call_hook('before_train_epoch')
-        if self._iter > 0:
+        if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler):
             print_log(
                 f'Advance dataloader {self._iter} steps to skip data '
                 'that has already been trained',
                 logger='current',
                 level=logging.WARNING)
             for _ in range(self._iter):
+                break  # NOTE MGAM: override all preprocessing steps during resume.
                 next(self.dataloader_iterator)
+
+        # with torch.profiler.profile(
+        #     activities=[torch.profiler.ProfilerActivity.CPU,
+        #                 torch.profiler.ProfilerActivity.CUDA],
+        #     schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
+        #     on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'),
+        #     record_shapes=True,
+        #     profile_memory=True,
+        #     with_stack=False,
+        #     with_flops=True,
+        #     with_modules=True,
+        # ) as p:
+
 
         while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()
@@ -292,8 +307,10 @@ def run(self) -> None:
             if (self.runner.val_loop is not None
                     and self._iter >= self.val_begin
                     and (self._iter % self.val_interval == 0
-                         or self._iter == self._max_iters)):
+                        or self._iter == self._max_iters)):
                 self.runner.val_loop.run()
+
+            # p.step()
 
         self.runner.call_hook('after_train_epoch')
         self.runner.call_hook('after_train')
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 7d1f655aad..f89fb260a1 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import inspect
 import logging
 import os
 import os.path as osp
@@ -902,8 +903,20 @@ def wrap_model(
                 find_unused_parameters=find_unused_parameters)
         else:
             model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
-            model_wrapper_type = MODEL_WRAPPERS.get(
-                model_wrapper_cfg.get('type'))  # type: ignore
+
+            model_wrapper_type = model_wrapper_cfg.get('type')
+            if isinstance(model_wrapper_type, str):
+                model_wrapper_type = MODEL_WRAPPERS.get(
+                    model_wrapper_type)  # type: ignore
+            elif inspect.isclass(model_wrapper_type):
+                pass
+            else:
+                raise KeyError(
+                    f'{model_wrapper_type} is not in the '
+                    'registry. Please check whether the value of '
+                    f'`{model_wrapper_type}` is correct or it was registered '
+                    'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
+                )
             default_args: dict = dict()
             if issubclass(
                     model_wrapper_type,  # type: ignore
@@ -1838,7 +1851,7 @@ def call_hook(self, fn_name: str, **kwargs) -> None:
                 try:
                     getattr(hook, fn_name)(self, **kwargs)
                 except TypeError as e:
-                    raise TypeError(f'{e} in {hook}') from None
+                    raise TypeError(f'{e} in {hook}') from e
 
     def register_hook(
             self,
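
Usage sketch (not part of the patch): with the wrap_model change above, `model_wrapper_cfg['type']` may be either a string resolved through the MODEL_WRAPPERS registry or a wrapper class passed directly; any other value raises KeyError. A minimal, assumed config fragment follows; the chosen keyword arguments are illustrative and not taken from this diff.

    from mmengine.model import MMDistributedDataParallel

    # Passing the class object resolves through the new inspect.isclass()
    # branch instead of a MODEL_WRAPPERS string lookup.
    model_wrapper_cfg = dict(
        type=MMDistributedDataParallel,
        broadcast_buffers=False,
        find_unused_parameters=True)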