
Commit

Merge 8489c93 into a0eaf22
ckkelvinchan committed May 20, 2021
2 parents a0eaf22 + 8489c93 commit f1fc363
Showing 8 changed files with 475 additions and 22 deletions.
37 changes: 30 additions & 7 deletions mmedit/apis/train.py
@@ -7,10 +7,11 @@
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import HOOKS, IterBasedRunner
from mmcv.runner import HOOKS, IterBasedRunner, build_runner

from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers
from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
from mmedit.core.runner.apex_amp_utils import apex_amp_initialize
from mmedit.datasets.builder import build_dataloader, build_dataset
from mmedit.utils import get_root_logger

@@ -133,12 +134,33 @@ def _dist_train(model,

# build runner
optimizer = build_optimizers(model, cfg.optimizers)
runner = IterBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)

# use apex amp
_use_apex_amp = False
if cfg.get('apex_amp', None):
model, optimizer = apex_amp_initialize(model, optimizer,
**cfg.apex_amp)
_use_apex_amp = True

# allow users to define the runner
if cfg.get('runner', None):
runner = build_runner(
cfg.runner,
dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
use_apex_amp=_use_apex_amp,
meta=meta))
else:
runner = IterBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)

    # an ugly workaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp

@@ -189,6 +211,7 @@ def _dist_train(model,
runner.run(data_loaders, cfg.workflow, cfg.total_iters)


# TODO: Support fp16 for non-distributed training
def _non_dist_train(model,
dataset,
cfg,
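For illustration, a config snippet that exercises the new options in train.py might look like the sketch below. The field values are hypothetical; only the top-level keys matter: `apex_amp` is expanded into the `apex_amp_initialize(...)` call, and `runner`, if present, is passed to mmcv's `build_runner`. The example assumes `IterBasedFP16Runner` is registered with mmcv's RUNNERS registry (its definition is not shown in this diff) and uses apex's standard `opt_level='O1'`.

# Hypothetical config snippet; values are illustrative, not prescribed by this commit.
apex_amp = dict(
    mode='gan',                      # wrapping mode expected by apex_amp_initialize
    init_args=dict(opt_level='O1'),  # forwarded to apex.amp.initialize
)
runner = dict(type='IterBasedFP16Runner')  # built via build_runner when the key is present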
3 changes: 2 additions & 1 deletion mmedit/core/__init__.py
@@ -3,10 +3,11 @@
from .hooks import VisualizationHook
from .misc import tensor2img
from .optimizer import build_optimizers
from .runner import IterBasedFP16Runner
from .scheduler import LinearLrUpdaterHook

__all__ = [
'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook',
'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook',
'VisualizationHook', 'L1Evaluation'
'VisualizationHook', 'L1Evaluation', 'IterBasedFP16Runner'
]
3 changes: 3 additions & 0 deletions mmedit/core/runner/__init__.py
@@ -0,0 +1,3 @@
from .iter_based_fp16_runner import IterBasedFP16Runner

__all__ = ['IterBasedFP16Runner']
33 changes: 33 additions & 0 deletions mmedit/core/runner/apex_amp_utils.py
@@ -0,0 +1,33 @@
try:
from apex import amp
except ImportError:
amp = None


def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
"""Initialize apex.amp for mixed-precision training.
Args:
models (nn.Module | list[Module]): Modules to be wrapped with apex.amp.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
init_args (dict | None, optional): Config for amp initialization.
Defaults to None.
mode (str, optional): The moded used to initialize the apex.map.
Different modes lead to different wrapping mode for models and
optimizers. Defaults to 'gan'.
Returns:
Module, :obj:`Optimizer`: Wrapped module and optimizer.
"""
init_args = init_args or dict()

if mode == 'gan':
        _optimizers = [optimizers['generator'], optimizers['discriminator']]

        models, _optimizers = amp.initialize(models, _optimizers, **init_args)
        optimizers['generator'] = _optimizers[0]
        optimizers['discriminator'] = _optimizers[1]

return models, optimizers

else:
raise NotImplementedError(
f'Cannot initialize apex.amp with mode {mode}')
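A minimal usage sketch of `apex_amp_initialize` in its 'gan' mode, mirroring the call in `_dist_train` above. The toy module and the `generator`/`discriminator` attribute names are invented for illustration; only the dict keys 'generator' and 'discriminator' are required by this function, and NVIDIA apex must be installed.

import torch
import torch.nn as nn

from mmedit.core.runner.apex_amp_utils import apex_amp_initialize

# Sketch only: a toy GAN-like module with a generator and a discriminator.
class ToyGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(8, 8)
        self.discriminator = nn.Linear(8, 1)

model = ToyGAN().cuda()  # apex.amp expects CUDA modules
optimizers = dict(
    generator=torch.optim.Adam(model.generator.parameters(), lr=1e-4),
    discriminator=torch.optim.Adam(model.discriminator.parameters(), lr=1e-4),
)
# Wrap the model and both optimizers; 'gan' mode returns the same dict with wrapped optimizers.
model, optimizers = apex_amp_initialize(
    model, optimizers, init_args=dict(opt_level='O1'), mode='gan')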
91 changes: 91 additions & 0 deletions mmedit/core/runner/checkpoint.py
@@ -0,0 +1,91 @@
import os.path as osp
import time
from tempfile import TemporaryDirectory

import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer


def save_checkpoint(model,
filename,
optimizer=None,
loss_scaler=None,
save_apex_amp=False,
meta=None):
"""Save checkpoint to file.
The checkpoint will have 3 or more fields: ``meta``, ``state_dict`` and
``optimizer``. By default ``meta`` will contain version and time info.
In mixed-precision training, ``loss_scaler`` or ``amp.state_dict`` will be
saved in checkpoint.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
loss_scaler (Object, optional): Loss scaler used for FP16 training.
save_apex_amp (bool, optional): Whether to save apex.amp
``state_dict``.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

if is_module_wrapper(model):
model = model.module

if hasattr(model, 'CLASSES') and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)

checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model))
}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint['optimizer'] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint['optimizer'] = {}
for name, optim in optimizer.items():
checkpoint['optimizer'][name] = optim.state_dict()

# save loss scaler for mixed-precision (FP16) training
if loss_scaler is not None:
checkpoint['loss_scaler'] = loss_scaler.state_dict()

# save state_dict from apex.amp
if save_apex_amp:
from apex import amp
checkpoint['amp'] = amp.state_dict()

if filename.startswith('pavi://'):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
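A hedged usage sketch of the new `save_checkpoint` signature, reusing the toy model and wrapped optimizers from the apex.amp sketch above; the output path and the extra metadata are illustrative only.

from mmedit.core.runner.checkpoint import save_checkpoint

# Sketch: save a checkpoint that also records the apex.amp state so FP16 training can resume.
save_checkpoint(
    model,                                # the (possibly wrapped) model to serialize
    'work_dirs/example/iter_10000.pth',   # illustrative filename
    optimizer=optimizers,                 # a single Optimizer or a dict of optimizers
    save_apex_amp=True,                   # stores amp.state_dict() under the 'amp' key
    meta=dict(iter=10000),                # merged with the auto-added version/time info
)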