diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 5cd3ff24e7..f5afbb7f70 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -170,7 +170,7 @@ def __init__(self, with_cp=False, pretrained=None, init_cfg=None): - super(VisionTransformer, self).__init__() + super(VisionTransformer, self).__init__(init_cfg=init_cfg) if isinstance(img_size, int): img_size = to_2tuple(img_size) @@ -185,10 +185,13 @@ def __init__(self, assert with_cls_token is True, f'with_cls_token must be True if' \ f'set output_cls_token to True, but got {with_cls_token}' - if isinstance(pretrained, str) or pretrained is None: - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' 'please use "init_cfg" instead') - else: + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: raise TypeError('pretrained must be a str or None') self.img_size = img_size @@ -197,7 +200,6 @@ def __init__(self, self.norm_eval = norm_eval self.with_cp = with_cp self.pretrained = pretrained - self.init_cfg = init_cfg self.patch_embed = PatchEmbed( in_channels=in_channels, @@ -260,10 +262,12 @@ def norm1(self): return getattr(self, self.norm1_name) def init_weights(self): - if isinstance(self.pretrained, str): + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): logger = get_root_logger() checkpoint = _load_checkpoint( - self.pretrained, logger=logger, map_location='cpu') + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] else: @@ -283,9 +287,9 @@ def init_weights(self): (pos_size, pos_size), self.interpolate_mode) self.load_state_dict(state_dict, False) - - elif self.pretrained is None: + elif self.init_cfg is not None: super(VisionTransformer, self).init_weights() + else: # We only implement the 'jax_impl' initialization implemented at # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 trunc_normal_init(self.pos_embed, std=.02) diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py index 5dbb51e64a..4ce860c041 100644 --- a/tests/test_models/test_backbones/test_vit.py +++ b/tests/test_models/test_backbones/test_vit.py @@ -118,3 +118,59 @@ def test_vit_backbone(): feat = model(imgs) assert feat[0][0].shape == (1, 768, 14, 14) assert feat[0][1].shape == (1, 768) + + +def test_vit_init(): + path = 'PATH_THAT_DO_NOT_EXIST' + # Test all combinations of pretrained and init_cfg + # pretrained=None, init_cfg=None + model = VisionTransformer(pretrained=None, init_cfg=None) + assert model.init_cfg is None + model.init_weights() + + # pretrained=None + # init_cfg loads pretrain from an non-existent file + model = VisionTransformer( + pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) + assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + # Test loading a checkpoint from an non-existent file + with pytest.raises(OSError): + model.init_weights() + + # pretrained=None + # init_cfg=123, whose type is unsupported + model = VisionTransformer(pretrained=None, init_cfg=123) + with pytest.raises(TypeError): + model.init_weights() + + # pretrained loads pretrain from an non-existent file + # init_cfg=None + model = VisionTransformer(pretrained=path, init_cfg=None) + assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + # Test loading a checkpoint from an non-existent file + with pytest.raises(OSError): + model.init_weights() + + # pretrained loads pretrain from an non-existent file + # init_cfg loads pretrain from an non-existent file + with pytest.raises(AssertionError): + model = VisionTransformer( + pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + with pytest.raises(AssertionError): + model = VisionTransformer(pretrained=path, init_cfg=123) + + # pretrain=123, whose type is unsupported + # init_cfg=None + with pytest.raises(TypeError): + model = VisionTransformer(pretrained=123, init_cfg=None) + + # pretrain=123, whose type is unsupported + # init_cfg loads pretrain from an non-existent file + with pytest.raises(AssertionError): + model = VisionTransformer( + pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + + # pretrain=123, whose type is unsupported + # init_cfg=123, whose type is unsupported + with pytest.raises(AssertionError): + model = VisionTransformer(pretrained=123, init_cfg=123)