[Enhance] EMAHook supports not loading checkpoint strictly (open-mmlab#352)

* BaseAveragedModel supports loading ckpt without module prefix

* refine docstring

* allow EMAHook to load ckpt non-strictly

* add unit test for the strict_load argument of EMAHook

* sync remote

* sync remote

* clean the code

* ema hook supports setting start iter

* fix unit test

* fix as comment

* fix as comment

* describe kwargs
HAOCHENYE authored Aug 8, 2022
1 parent 08602a2 commit 99de095
Showing 2 changed files with 178 additions and 30 deletions.
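The new ``strict_load``, ``begin_iter`` and ``begin_epoch`` options are all passed through the hook dict. A minimal config sketch, assuming a standard mmengine config; the values simply mirror the unit tests added below:

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExponentialMovingAverage',  # default strategy
        strict_load=False,  # tolerate missing/unexpected keys when resuming
        begin_epoch=5),     # start EMA updates from the 5th epoch
]
# Note: begin_iter and begin_epoch must not both be set to non-zero values.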
122 changes: 92 additions & 30 deletions mmengine/hooks/ema_hook.py
@@ -19,17 +19,44 @@ class EMAHook(Hook):
- EMAHook takes priority over CheckpointHook.
- The original model parameters are actually saved in the ema field after
training.
- ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.
Args:
ema_type (str): The type of EMA strategy to use. You can find the
supported strategies in :mod:`mmengine.model.averaged_model`.
Defaults to 'ExponentialMovingAverage'.
strict_load (bool): Whether to strictly enforce that the keys of
``state_dict`` in checkpoint match the keys returned by
``self.module.state_dict``. Defaults to True.
begin_iter (int): The iteration at which ``EMAHook`` starts to take
effect. Defaults to 0.
begin_epoch (int): The epoch at which ``EMAHook`` starts to take
effect. Defaults to 0.
**kwargs: Keyword arguments passed to subclasses of
:obj:`BaseAveragedModel`.
"""

priority = 'NORMAL'

def __init__(self,
ema_type: str = 'ExponentialMovingAverage',
strict_load: bool = True,
begin_iter: int = 0,
begin_epoch: int = 0,
**kwargs):
self.strict_load = strict_load
self.ema_cfg = dict(type=ema_type, **kwargs)
assert not (begin_iter != 0 and begin_epoch != 0), (
'`begin_iter` and `begin_epoch` should not be both set.')
assert begin_iter >= 0, (
f'begin_iter must not be smaller than 0, but got {begin_iter}')
assert begin_epoch >= 0, (
f'begin_epoch must not be smaller than 0, but got {begin_epoch}')
self.begin_iter = begin_iter
self.begin_epoch = begin_epoch
# If neither `begin_epoch` nor `begin_iter` is set, `EMAHook` will be
# enabled from the very first iteration.
self.enabled_by_epoch = self.begin_epoch > 0

def before_run(self, runner) -> None:
"""Create an ema copy of the model."""
@@ -40,64 +67,81 @@ def before_run(self, runner) -> None:
self.ema_model = MODELS.build(
self.ema_cfg, default_args=dict(model=self.src_model))

if self.enabled_by_epoch:
assert self.begin_epoch <= runner.max_epochs, (
'self.begin_epoch should not be larger than runner.max_epochs: '
f'{runner.max_epochs}, but got begin_epoch: {self.begin_epoch}')
else:
assert self.begin_iter <= runner.max_iters, (
'self.begin_iter should not be larger than runner.max_iters: '
f'{runner.max_iters}, but got begin_iter: {self.begin_iter}')

def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Update ema parameter."""
if self._ema_started(runner):
self.ema_model.update_parameters(self.src_model)

def before_val_epoch(self, runner) -> None:
"""We load parameter values from ema model to source model before
validation."""
if self._ema_started(runner):
self._swap_ema_parameters()

def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""We recover source model's parameter from ema model after
validation."""
if self._ema_started(runner):
self._swap_ema_parameters()

def before_test_epoch(self, runner) -> None:
"""We load parameter values from ema model to source model before
test."""
if self._ema_started(runner):
self._swap_ema_parameters()

def after_test_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""We recover source model's parameter from ema model after test."""
if self._ema_started(runner):
self._swap_ema_parameters()

def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
"""Save ema parameters to checkpoint."""
if self._ema_started(runner):
checkpoint['ema_state_dict'] = self.ema_model.state_dict()
# Save ema parameters to the source model's state dict so that we
# can directly load the averaged model weights for deployment.
# Swapping the state_dict key-values instead of swapping model
# parameters because the state_dict is a shallow copy of model
# parameters.
self._swap_ema_state_dict(checkpoint)

def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
"""Resume ema parameters from checkpoint."""

if self._ema_started(runner):
if 'ema_state_dict' in checkpoint:
# The original model parameters are actually saved in the ema
# field. Swap the weights back to resume the ema state.
self._swap_ema_state_dict(checkpoint)
self.ema_model.load_state_dict(
checkpoint['ema_state_dict'], strict=self.strict_load)

# Support loading a checkpoint without an ema state dict.
else:
print_log(
'There is no `ema_state_dict` in checkpoint. '
'`EMAHook` will make a copy of `state_dict` as the '
'initial `ema_state_dict`', 'current', logging.WARNING)
self.ema_model.module.load_state_dict(
copy.deepcopy(checkpoint['state_dict']),
strict=self.strict_load)

def _swap_ema_parameters(self) -> None:
"""Swap the parameter of model with ema_model."""
Expand All @@ -124,3 +168,21 @@ def _swap_ema_state_dict(self, checkpoint):
tmp = ema_state[k]
ema_state[k] = model_state[k[7:]]
model_state[k[7:]] = tmp

def _ema_started(self, runner) -> bool:
"""Whether ``EMAHook`` has been initialized at current iteration or
epoch.
:attr:`ema_model` will be initialized when ``runner.iter`` or
``runner.epoch`` is greater than ``self.begin`` for the first time.
Args:
runner (Runner): Runner of the training, validation process.
Returns:
bool: Whether ``EMAHook`` has been initialized.
"""
if self.enabled_by_epoch:
return runner.epoch + 1 >= self.begin_epoch
else:
return runner.iter + 1 >= self.begin_iter
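The non-strict path above ultimately relies on PyTorch's ``load_state_dict(..., strict=False)``, which loads the overlapping keys and reports, rather than raises on, missing and unexpected ones. A self-contained sketch of that behaviour (the toy modules here are illustrative only, not part of the commit):

import torch.nn as nn

old_model = nn.Sequential(nn.Linear(2, 2))                   # checkpoint source
new_model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))  # has an extra layer

ckpt = old_model.state_dict()

# strict=True (the default) would raise a RuntimeError because '1.weight'
# and '1.bias' are missing from the checkpoint.
result = new_model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # []

``EMAHook`` forwards ``strict_load`` straight into this call, which is what lets ``ToyModel2`` in the tests below resume from a checkpoint saved with ``ToyModel`` even though its extra ``linear1`` layer has no stored weights.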
86 changes: 86 additions & 0 deletions tests/test_hook/test_ema_hook.py
@@ -43,6 +43,16 @@ def forward(self, *args, **kwargs):
return super(BaseModel, self).forward(*args, **kwargs)


class ToyModel2(BaseModel, ToyModel):

def __init__(self):
super().__init__()
self.linear1 = nn.Linear(2, 1)

def forward(self, *args, **kwargs):
return super(BaseModel, self).forward(*args, **kwargs)


@DATASETS.register_module()
class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
@@ -171,3 +181,79 @@ def forward(self, *args, **kwargs):
custom_hooks=[dict(type='EMAHook')],
experiment_name='test4')
runner.test()

# Test not loading the checkpoint strictly (strict_load=False).
# Test loading a checkpoint without ema_state_dict.
runner = Runner(
model=ToyModel2(),
test_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
test_evaluator=evaluator,
test_cfg=dict(),
work_dir=self.temp_dir.name,
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook', strict_load=False)],
experiment_name='test5')
runner.test()

# Test enabling EMA from the 5th epoch (begin_epoch=5).
runner = Runner(
model=model,
train_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=evaluator,
work_dir=self.temp_dir.name,
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
val_cfg=dict(),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
experiment_name='test6')
runner.train()
state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
self.assertNotIn('ema_state_dict', state_dict)
state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
self.assertIn('ema_state_dict', state_dict)

# Test enabling EMA from the 5th iteration (begin_iter=5).
runner = Runner(
model=model,
train_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=evaluator,
work_dir=self.temp_dir.name,
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(by_epoch=False, max_iters=10, val_interval=1),
val_cfg=dict(),
default_hooks=dict(
checkpoint=dict(
type='CheckpointHook', interval=1, by_epoch=False)),
custom_hooks=[dict(type='EMAHook', begin_iter=5)],
experiment_name='test7')
runner.train()
state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
self.assertNotIn('ema_state_dict', state_dict)
state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
self.assertIn('ema_state_dict', state_dict)

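For reference, the key swap performed in ``before_save_checkpoint`` via ``_swap_ema_state_dict`` can be reproduced on a toy checkpoint. This is a simplified sketch with made-up tensor values, not the hook's actual implementation:

import torch

checkpoint = {
    'state_dict': {'linear.weight': torch.zeros(1)},             # raw weights
    'ema_state_dict': {'module.linear.weight': torch.ones(1)},   # averaged weights
}

model_state = checkpoint['state_dict']
ema_state = checkpoint['ema_state_dict']
for k in ema_state:
    if k.startswith('module.'):
        # 'module.' is 7 characters long, hence the k[7:] in the hook itself
        model_state[k[7:]], ema_state[k] = ema_state[k], model_state[k[7:]]

# 'state_dict' now carries the averaged weights, so the checkpoint can be
# loaded directly for deployment, while 'ema_state_dict' keeps the raw ones.
print(model_state['linear.weight'])       # tensor([1.])
print(ema_state['module.linear.weight'])  # tensor([0.])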