Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Fix] Fix the bug that vit cannot load pretrain properly when using i… #999

Merged
merged 9 commits into from
Nov 3, 2021
27 changes: 19 additions & 8 deletions mmseg/models/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(self,
with_cp=False,
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__()
super(VisionTransformer, self).__init__(init_cfg)
RockeyCoss marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(img_size, int):
img_size = to_2tuple(img_size)
Expand All @@ -185,10 +185,13 @@ def __init__(self,
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'

if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
else:
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')

self.img_size = img_size
Expand All @@ -197,7 +200,6 @@ def __init__(self,
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrained = pretrained
self.init_cfg = init_cfg

self.patch_embed = PatchEmbed(
in_channels=in_channels,
Expand Down Expand Up @@ -260,10 +262,19 @@ def norm1(self):
return getattr(self, self.norm1_name)

def init_weights(self):
if isinstance(self.pretrained, str):
if (isinstance(self.pretrained, str)
or isinstance(self.init_cfg, dict) and 'type' in self.init_cfg
and self.init_cfg['type'] == 'Pretrained'):
logger = get_root_logger()
checkpoint = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
if self.pretrained:
checkpoint = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
else:
checkpoint = _load_checkpoint(
self.init_cfg['checkpoint'],
logger=logger,
map_location='cpu')

if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_models/test_backbones/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,44 @@ def test_vit_backbone():
feat = model(imgs)
assert feat[0][0].shape == (1, 768, 14, 14)
assert feat[0][1].shape == (1, 768)


def test_vit_init():
    """Exercise every combination of ``pretrained`` and ``init_cfg``.

    Verifies that the deprecated ``pretrained`` argument is converted to an
    equivalent ``Pretrained`` ``init_cfg``, that the two cannot be supplied
    together, and that invalid values raise the expected exceptions.
    """
    path = 'PATH_THAT_DO_NOT_EXIST'

    # Neither pretrained nor init_cfg: default (random) initialization works.
    model = VisionTransformer(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # init_cfg only: stored unchanged on the model.
    model = VisionTransformer(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # Invalid init_cfg type: surfaces as TypeError when weights are built.
    model = VisionTransformer(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained only (deprecated path): converted into a Pretrained init_cfg.
    model = VisionTransformer(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # Supplying pretrained and init_cfg together is disallowed.
    with pytest.raises(AssertionError):
        model = VisionTransformer(
            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        model = VisionTransformer(pretrained=path, init_cfg=123)

    # Non-str, non-None pretrained alone is a TypeError.
    with pytest.raises(TypeError):
        model = VisionTransformer(pretrained=123, init_cfg=None)

    # When init_cfg is also set, the mutual-exclusion assert fires first.
    with pytest.raises(AssertionError):
        model = VisionTransformer(
            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        model = VisionTransformer(pretrained=123, init_cfg=123)
RockeyCoss marked this conversation as resolved.
Show resolved Hide resolved