Several Improvements for the latest PyTorch Framework #1564

Open · wants to merge 11 commits into base: main
3 changes: 2 additions & 1 deletion mmengine/_strategy/base.py
@@ -322,7 +322,8 @@ def compile_model(
Returns:
nn.Module: Compiled model.
"""
-if isinstance(compile, bool) and not compile:
+if (isinstance(compile, bool) and not compile) or \
+        (isinstance(compile, dict) and compile.get('disable', False)):
    return model

assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
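Note: with this change `compile` may be a dict as well as a bool, and a dict config is skipped only when its `disable` key is true. A minimal sketch of the dict form, assuming the remaining keys are forwarded to `torch.compile` (the `mode` value below is illustrative, not part of this diff):

# Hypothetical runner/strategy config using the dict form.
cfg = dict(
    compile=dict(
        disable=False,        # True would return the model uncompiled
        mode='max-autotune',  # forwarded to torch.compile
    ),
)
# compile=False (plain bool) still bypasses compilation entirely.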
3 changes: 3 additions & 0 deletions mmengine/model/wrappers/distributed.py
@@ -95,6 +95,7 @@ def __init__(self,

def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
+    return self.module.train_step(data, optim_wrapper)
"""Interface for model forward, backward and parameters updating during
training process.

@@ -126,6 +127,7 @@ def train_step(self, data: Union[dict, tuple, list],
return log_vars

def val_step(self, data: Union[dict, tuple, list]) -> list:
+    return self.module.val_step(data)
"""Gets the prediction of module during validation process.

Args:
@@ -137,6 +139,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list:
return self.module.val_step(data)

def test_step(self, data: Union[dict, tuple, list]) -> list:
+    return self.module.test_step(data)
"""Gets the predictions of module during testing process.

Args:
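Note: each added early `return` turns these wrapper methods into thin pass-throughs; the wrapper's own logic below (data preprocessing and the optimizer update in the original `train_step`) and the docstrings become unreachable. The resulting pattern, as a self-contained sketch (`PassThroughWrapper` is an illustrative name, not this class):

# Minimal pass-through pattern this patch creates: the inner module is
# now responsible for preprocessing and parameter updates itself.
class PassThroughWrapper:
    def __init__(self, module):
        self.module = module

    def train_step(self, data, optim_wrapper):
        return self.module.train_step(data, optim_wrapper)

    def val_step(self, data):
        return self.module.val_step(data)

    def test_step(self, data):
        return self.module.test_step(data)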
4 changes: 3 additions & 1 deletion mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from functools import partial
from typing import Union

import torch
@@ -17,7 +18,8 @@
elif is_mlu_available():
from torch.mlu.amp import GradScaler
else:
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler as amp_GradScaler
+GradScaler = partial(amp_GradScaler, device='cuda')


@OPTIM_WRAPPERS.register_module()
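Note: `torch.cuda.amp.GradScaler` is deprecated in recent PyTorch (2.4+) in favor of the unified `torch.amp.GradScaler`, which takes the target device as an argument; the `partial` keeps existing CUDA call sites working unchanged. A self-contained usage sketch of the alias (requires a CUDA device; the model and optimizer are illustrative):

from functools import partial

import torch
from torch import nn
from torch.amp import GradScaler as amp_GradScaler

GradScaler = partial(amp_GradScaler, device='cuda')  # as in the diff above

model = nn.Linear(4, 2).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()  # equivalent to torch.amp.GradScaler(device='cuda')
with torch.autocast('cuda'):
    loss = model(torch.randn(8, 4, device='cuda')).sum()
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(opt)               # unscales gradients, then calls optimizer.step()
scaler.update()                # adjusts the scale factor for the next iteration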
1 change: 1 addition & 0 deletions mmengine/optim/optimizer/builder.py
@@ -207,5 +207,6 @@ def build_optim_wrapper(model: nn.Module,
type=constructor_type,
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg))

optim_wrapper = optim_wrapper_constructor(model)
return optim_wrapper
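For context, a typical call into this builder, as a sketch (the stand-in model and optimizer settings are illustrative):

import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(4, 2)  # stand-in model
optim_wrapper = build_optim_wrapper(
    model,
    dict(type='OptimWrapper',
         optimizer=dict(type='SGD', lr=0.01, momentum=0.9)))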
21 changes: 19 additions & 2 deletions mmengine/runner/loops.py
@@ -12,6 +12,7 @@
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from mmengine.dataset.sampler import InfiniteSampler
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@@ -274,14 +275,28 @@ def run(self) -> None:
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
-if self._iter > 0:
+if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler):
print_log(
f'Advance dataloader {self._iter} steps to skip data '
'that has already been trained',
logger='current',
level=logging.WARNING)
for _ in range(self._iter):
+    break  # NOTE MGAM: override all preprocessing steps during resume.
next(self.dataloader_iterator)

# with torch.profiler.profile(
# activities=[torch.profiler.ProfilerActivity.CPU,
# torch.profiler.ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'),
# record_shapes=True,
# profile_memory=True,
# with_stack=False,
# with_flops=True,
# with_modules=True,
# ) as p:

while self._iter < self._max_iters and not self.stop_training:
self.runner.model.train()

@@ -292,8 +307,10 @@ def run(self) -> None:
if (self.runner.val_loop is not None
and self._iter >= self.val_begin
and (self._iter % self.val_interval == 0
or self._iter == self._max_iters)):
self.runner.val_loop.run()

# p.step()

self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train')
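Two behavioral notes on this file. First, an `InfiniteSampler` draws indices endlessly from its seed, so the resume fast-forward is now skipped for it. Second, per the `NOTE MGAM` comment, the added `break` exits the replay loop on its first pass, so no batches are re-consumed on resume even for finite samplers, although the warning is still logged. A self-contained illustration of that dead-code pattern:

# The break runs before next(), so nothing is ever consumed
# from the iterator, exactly as in the patched replay loop.
it = iter(range(10))
for _ in range(5):
    break        # exits on the first pass
    next(it)     # unreachable
print(next(it))  # -> 0: the iterator was never advanced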
19 changes: 16 additions & 3 deletions mmengine/runner/runner.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import logging
import os
import os.path as osp
@@ -902,8 +903,20 @@ def wrap_model(
find_unused_parameters=find_unused_parameters)
else:
model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
-model_wrapper_type = MODEL_WRAPPERS.get(
-    model_wrapper_cfg.get('type'))  # type: ignore
+model_wrapper_type = model_wrapper_cfg.get('type')
+if isinstance(model_wrapper_type, str):
+    model_wrapper_type = MODEL_WRAPPERS.get(
+        model_wrapper_type)  # type: ignore
+elif inspect.isclass(model_wrapper_type):
+    pass
+else:
+    raise KeyError(
+        f'{model_wrapper_type} is not in the '
+        'registry. Please check whether the value of '
+        f'`{model_wrapper_type}` is correct or it was registered '
+        'as expected. More details can be found at '
+        'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module')  # noqa: E501
default_args: dict = dict()
if issubclass(
model_wrapper_type, # type: ignore
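With this lookup change, `model_wrapper_cfg` may name the wrapper class directly instead of a registry string; anything else now fails fast with a `KeyError`. Both accepted forms, as a sketch:

from mmengine.model import MMDistributedDataParallel

model_wrapper_cfg = dict(type='MMDistributedDataParallel')  # registry string
model_wrapper_cfg = dict(type=MMDistributedDataParallel)    # class object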
@@ -1838,7 +1851,7 @@ def call_hook(self, fn_name: str, **kwargs) -> None:
try:
    getattr(hook, fn_name)(self, **kwargs)
except TypeError as e:
-    raise TypeError(f'{e} in {hook}') from None
+    raise TypeError(f'{e} in {hook}') from e

def register_hook(
self,
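The `from e` change in `call_hook` preserves exception chaining: the hook's original `TypeError` is attached as `__cause__` and its traceback is shown, instead of being suppressed by `from None`. A standalone illustration (`SomeHook` is a placeholder name):

try:
    raise TypeError('unexpected keyword argument')  # failure inside a hook
except TypeError as e:
    raise TypeError(f'{e} in SomeHook') from e
# The report now includes "The above exception was the direct cause of
# the following exception:", followed by the wrapping TypeError.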