Skip to content

Commit

Permalink
[Feature] Support revise_keys in load_checkpoint(). (#829)
Browse files Browse the repository at this point in the history
* Simplified the code.

* Improved chkpt compatibility.

* One may modify the checkpoint via adding keywords.

* Tiny.

* Following reviewer's suggestion.

* Added unit_test.

* Fixed.

* Modify the state_dict  with  construction.

* Added test.

* Modified。

* Mimimalised the modification.

* Added the docstring.

* Format.

* Improved.

* Tiny.

* Temp file.

* Added assertion.

* Doc string.

* Fixed.
  • Loading branch information
yaochaorui authored Mar 3, 2021
1 parent 34b552b commit e076c8b
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 8 deletions.
1 change: 1 addition & 0 deletions mmcv/cnn/utils/weight_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ class PretrainedInit(object):
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
map_location (str): map tensors into proper locations.
"""

def __init__(self, checkpoint, prefix=None, map_location=None):
Expand Down
17 changes: 13 additions & 4 deletions mmcv/runner/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def __init__(self,
self.optimizer = optimizer
self.logger = logger
self.meta = meta

# create work_dir
if mmcv.is_str(work_dir):
self.work_dir = osp.abspath(work_dir)
Expand Down Expand Up @@ -307,10 +306,20 @@ def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)

def load_checkpoint(self, filename, map_location='cpu', strict=False):
def load_checkpoint(self,
filename,
map_location='cpu',
strict=False,
revise_keys=[(r'^module.', '')]):

self.logger.info('load checkpoint from %s', filename)
return load_checkpoint(self.model, filename, map_location, strict,
self.logger)
return load_checkpoint(
self.model,
filename,
map_location,
strict,
self.logger,
revise_keys=revise_keys)

def resume(self,
checkpoint,
Expand Down
13 changes: 10 additions & 3 deletions mmcv/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import os.path as osp
import pkgutil
import re
import time
import warnings
from collections import OrderedDict
Expand Down Expand Up @@ -503,7 +504,8 @@ def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None):
logger=None,
revise_keys=[(r'^module\.', '')]):
"""Load checkpoint from a file or URI.
Args:
Expand All @@ -515,6 +517,11 @@ def load_checkpoint(model,
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
revise_keys (list): A list of customized keywords to modify the
state_dict in checkpoint. Each item is a (pattern, replacement)
pair of the regular expression operations. Default: strip
the prefix 'module.' by [(r'^module\\.', '')].
Returns:
dict or OrderedDict: The loaded checkpoint.
Expand All @@ -530,8 +537,8 @@ def load_checkpoint(model,
else:
state_dict = checkpoint
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
for p, r in revise_keys:
state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
Expand Down
35 changes: 34 additions & 1 deletion tests/test_runner/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
get_state_dict, load_from_pavi)
get_state_dict, load_checkpoint,
load_from_pavi)


@MODULE_WRAPPERS.register_module()
Expand Down Expand Up @@ -188,6 +189,38 @@ def __init__(self):
_load_checkpoint_with_prefix(prefix, 'model.pth')


def test_load_checkpoint():
import os
import tempfile
import re

class PrefixModel(nn.Module):

def __init__(self):
super().__init__()
self.backbone = Model()

pmodel = PrefixModel()
model = Model()
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')

# add prefix
torch.save(model.state_dict(), checkpoint_path)
state_dict = load_checkpoint(
pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')])
for key in pmodel.backbone.state_dict().keys():
assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])
# strip prefix
torch.save(pmodel.state_dict(), checkpoint_path)
state_dict = load_checkpoint(
model, checkpoint_path, revise_keys=[(r'^backbone\.', '')])

for key in state_dict.keys():
key_stripped = re.sub(r'^backbone\.', '', key)
assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
os.remove(checkpoint_path)


def test_load_classes_name():
import os

Expand Down
40 changes: 40 additions & 0 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import logging
import os.path as osp
import re
import shutil
import sys
import tempfile
Expand Down Expand Up @@ -415,3 +416,42 @@ def val_step(self, x, optimizer, **kwargs):
runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner


def test_runner_with_revise_keys():

import os

class Model(nn.Module):

def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)

class PrefixModel(nn.Module):

def __init__(self):
super().__init__()
self.backbone = Model()

pmodel = PrefixModel()
model = Model()
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')

# add prefix
torch.save(model.state_dict(), checkpoint_path)
runner = _build_demo_runner(runner_type='EpochBasedRunner')
runner.model = pmodel
state_dict = runner.load_checkpoint(
checkpoint_path, revise_keys=[(r'^', 'backbone.')])
for key in pmodel.backbone.state_dict().keys():
assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])
# strip prefix
torch.save(pmodel.state_dict(), checkpoint_path)
runner.model = model
state_dict = runner.load_checkpoint(
checkpoint_path, revise_keys=[(r'^backbone\.', '')])
for key in state_dict.keys():
key_stripped = re.sub(r'^backbone\.', '', key)
assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
os.remove(checkpoint_path)

0 comments on commit e076c8b

Please sign in to comment.