[Feature] Support fp16 training #320

Closed
wants to merge 2 commits into from
37 changes: 30 additions & 7 deletions mmedit/apis/train.py
@@ -7,10 +7,11 @@
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import HOOKS, IterBasedRunner
from mmcv.runner import HOOKS, IterBasedRunner, build_runner

from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers
from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
from mmedit.core.runner.apex_amp_utils import apex_amp_initialize
from mmedit.datasets.builder import build_dataloader, build_dataset
from mmedit.utils import get_root_logger

@@ -133,12 +134,33 @@ def _dist_train(model,

    # build runner
    optimizer = build_optimizers(model, cfg.optimizers)
    runner = IterBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)

    # use apex amp
    _use_apex_amp = False
    if cfg.get('apex_amp', None):
        model, optimizer = apex_amp_initialize(model, optimizer,
                                               **cfg.apex_amp)
        _use_apex_amp = True

    # allow users to define the runner
    if cfg.get('runner', None):
        runner = build_runner(
            cfg.runner,
            dict(
                model=model,
                optimizer=optimizer,
                work_dir=cfg.work_dir,
                logger=logger,
                use_apex_amp=_use_apex_amp,
                meta=meta))
    else:
        runner = IterBasedRunner(
            model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta)

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

@@ -189,6 +211,7 @@ def _dist_train(model,
    runner.run(data_loaders, cfg.workflow, cfg.total_iters)


# TODO: Support fp16 for non-distributed training
def _non_dist_train(model,
                    dataset,
                    cfg,
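For context, a minimal config sketch (not part of this diff) of how the new `apex_amp` and `runner` keys read by `_dist_train` might be set in an MMEditing config file. The `opt_level` value and the bare `IterBasedFP16Runner` entry are assumptions to be checked against apex and the new runner, not something this PR prescribes.

```python
# Hypothetical config snippet; keys mirror cfg.get('apex_amp') / cfg.get('runner') above.
apex_amp = dict(
    mode='gan',                      # default mode in apex_amp_initialize
    init_args=dict(opt_level='O1'))  # forwarded to apex.amp.initialize (assumed setting)

runner = dict(type='IterBasedFP16Runner')  # consumed by build_runner in _dist_train
```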
3 changes: 2 additions & 1 deletion mmedit/core/__init__.py
@@ -3,10 +3,11 @@
from .hooks import VisualizationHook
from .misc import tensor2img
from .optimizer import build_optimizers
from .runner import IterBasedFP16Runner
from .scheduler import LinearLrUpdaterHook

__all__ = [
    'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook',
    'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook',
    'VisualizationHook', 'L1Evaluation'
    'VisualizationHook', 'L1Evaluation', 'IterBasedFP16Runner'
]
3 changes: 3 additions & 0 deletions mmedit/core/runner/__init__.py
@@ -0,0 +1,3 @@
from .iter_based_fp16_runner import IterBasedFP16Runner

__all__ = ['IterBasedFP16Runner']
33 changes: 33 additions & 0 deletions mmedit/core/runner/apex_amp_utils.py
@@ -0,0 +1,33 @@
try:
    from apex import amp
except ImportError:
    amp = None


def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
    """Initialize apex.amp for mixed-precision training.

    Args:
        models (nn.Module | list[Module]): Modules to be wrapped with apex.amp.
        optimizers (dict[str, :obj:`Optimizer`]): Optimizers to be wrapped
            with apex.amp.
        init_args (dict | None, optional): Config for amp initialization.
            Defaults to None.
        mode (str, optional): The mode used to initialize apex.amp.
            Different modes lead to different wrapping modes for models and
            optimizers. Defaults to 'gan'.

    Returns:
        Module, dict: Wrapped module and optimizers.
    """
    init_args = init_args or dict()

    if mode == 'gan':
        _optimizers = [optimizers['generator'], optimizers['discriminator']]

        models, _optimizers = amp.initialize(models, _optimizers, **init_args)
        optimizers['generator'] = _optimizers[0]
        optimizers['discriminator'] = _optimizers[1]

        return models, optimizers

    else:
        raise NotImplementedError(
            f'Cannot initialize apex.amp with mode {mode}')
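For illustration, a minimal usage sketch of the helper above, assuming a hypothetical GAN module `MyGAN` with `.generator` and `.discriminator` submodules and an optimizer dict keyed the way the 'gan' branch expects; the learning rates and `opt_level` are placeholders.

```python
import torch

from mmedit.core.runner.apex_amp_utils import apex_amp_initialize

model = MyGAN().cuda()  # hypothetical nn.Module exposing .generator / .discriminator
optimizers = dict(
    generator=torch.optim.Adam(model.generator.parameters(), lr=1e-4),
    discriminator=torch.optim.Adam(model.discriminator.parameters(), lr=1e-4))

# Wrap the model and both optimizers with apex.amp in 'gan' mode.
model, optimizers = apex_amp_initialize(
    model, optimizers, init_args=dict(opt_level='O1'), mode='gan')
```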
91 changes: 91 additions & 0 deletions mmedit/core/runner/checkpoint.py
@@ -0,0 +1,91 @@
import os.path as osp
import time
from tempfile import TemporaryDirectory

import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer


def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    loss_scaler=None,
                    save_apex_amp=False,
                    meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 or more fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    In mixed-precision training, ``loss_scaler`` or ``amp.state_dict`` will be
    saved in checkpoint.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        loss_scaler (Object, optional): Loss scaler used for FP16 training.
        save_apex_amp (bool, optional): Whether to save apex.amp
            ``state_dict``.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    # save loss scaler for mixed-precision (FP16) training
    if loss_scaler is not None:
        checkpoint['loss_scaler'] = loss_scaler.state_dict()

    # save state_dict from apex.amp
    if save_apex_amp:
        from apex import amp
        checkpoint['amp'] = amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
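For the resume side, a rough sketch (not part of this PR) of how the extra fields written by `save_checkpoint` above could be restored. Here `model`, `optimizers`, and `loss_scaler` are assumed to already exist, and the checkpoint path is only an example; the actual resume logic would live in `IterBasedFP16Runner`, which is not shown in this excerpt.

```python
import torch

ckpt = torch.load('work_dirs/example/iter_10000.pth', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])

# Per-optimizer state dicts, mirroring the dict branch in save_checkpoint.
for name, optim in optimizers.items():
    optim.load_state_dict(ckpt['optimizer'][name])

# Restore fp16 state if it was saved.
if 'loss_scaler' in ckpt:
    loss_scaler.load_state_dict(ckpt['loss_scaler'])
if 'amp' in ckpt:
    from apex import amp
    amp.load_state_dict(ckpt['amp'])
```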