diff --git a/mmcls/models/backbones/timm_backbone.py b/mmcls/models/backbones/timm_backbone.py
index 2e88d6057a0..58bc17eb7ac 100644
--- a/mmcls/models/backbones/timm_backbone.py
+++ b/mmcls/models/backbones/timm_backbone.py
@@ -4,52 +4,109 @@
 except ImportError:
     timm = None
 
+import warnings
+
+from mmcv.cnn.bricks.registry import NORM_LAYERS
+
+from ...utils import get_root_logger
 from ..builder import BACKBONES
 from .base_backbone import BaseBackbone
 
 
+def print_timm_feature_info(feature_info):
+    """Print feature_info of timm backbone to help development and debug.
+
+    Args:
+        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
+            feature_info of timm backbone.
+    """
+    logger = get_root_logger()
+    if feature_info is None:
+        logger.warning('This backbone does not have feature_info')
+    elif isinstance(feature_info, list):
+        for feat_idx, each_info in enumerate(feature_info):
+            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
+    else:
+        try:
+            logger.info(f'backbone out_indices: {feature_info.out_indices}')
+            logger.info(f'backbone out_channels: {feature_info.channels()}')
+            logger.info(f'backbone out_strides: {feature_info.reduction()}')
+        except AttributeError:
+            logger.warning('Unexpected format of backbone feature_info')
+
+
 @BACKBONES.register_module()
 class TIMMBackbone(BaseBackbone):
-    """Wrapper to use backbones from timm library. More details can be found in
-    `timm <https://github.com/rwightman/pytorch-image-models>`_ .
+    """Wrapper to use backbones from timm library.
+
+    More details can be found in
+    `timm <https://github.com/rwightman/pytorch-image-models>`_.
+    See especially the document for `feature extraction
+    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.
 
     Args:
         model_name (str): Name of timm model to instantiate.
-        pretrained (bool): Load pretrained weights if True.
-        checkpoint_path (str): Path of checkpoint to load after
-            model is initialized.
-        in_channels (int): Number of input image channels. Default: 3.
-        init_cfg (dict, optional): Initialization config dict
+        features_only (bool, optional): Whether to extract feature pyramid
+            (multi-scale feature maps from the deepest layer at each stride).
+            For Vision Transformer models that do not support this argument,
+            set this False. Default: False.
+        pretrained (bool, optional): Whether to load pretrained weights.
+            Default: False.
+        checkpoint_path (str, optional): Path of checkpoint to load at the
+            last of timm.create_model. Default: '', which means not loading.
+        in_channels (int, optional): Number of input image channels.
+            Default: 3.
+        init_cfg (dict or list[dict], optional): Initialization config dict of
+            OpenMMLab projects. Default: None.
         **kwargs: Other timm & model specific arguments.
     """
 
-    def __init__(
-        self,
-        model_name,
-        pretrained=False,
-        checkpoint_path='',
-        in_channels=3,
-        init_cfg=None,
-        **kwargs,
-    ):
+    def __init__(self,
+                 model_name,
+                 features_only=False,
+                 pretrained=False,
+                 checkpoint_path='',
+                 in_channels=3,
+                 init_cfg=None,
+                 **kwargs):
         if timm is None:
-            raise RuntimeError('timm is not installed')
+            raise RuntimeError(
+                'Failed to import timm. Please run "pip install timm". '
+                '"pip install dataclasses" may also be needed for Python 3.6.')
+        if not isinstance(pretrained, bool):
+            raise TypeError('pretrained must be bool, not str for model path')
+        if features_only and checkpoint_path:
+            warnings.warn(
+                'Using both features_only and checkpoint_path will cause error'
+                ' in timm. See '
+                'https://github.com/rwightman/pytorch-image-models/issues/488')
+
         super(TIMMBackbone, self).__init__(init_cfg)
+        if 'norm_layer' in kwargs:
+            kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
         self.timm_model = timm.create_model(
             model_name=model_name,
+            features_only=features_only,
             pretrained=pretrained,
             in_chans=in_channels,
             checkpoint_path=checkpoint_path,
-            **kwargs,
-        )
+            **kwargs)
 
         # reset classifier
-        self.timm_model.reset_classifier(0, '')
+        if hasattr(self.timm_model, 'reset_classifier'):
+            self.timm_model.reset_classifier(0, '')
 
         # Hack to use pretrained weights from timm
         if pretrained or checkpoint_path:
             self._is_init = True
 
+        feature_info = getattr(self.timm_model, 'feature_info', None)
+        print_timm_feature_info(feature_info)
+
     def forward(self, x):
-        features = self.timm_model.forward_features(x)
-        return (features, )
+        features = self.timm_model(x)
+        if isinstance(features, (list, tuple)):
+            features = tuple(features)
+        else:
+            features = (features, )
+        return features
diff --git a/tests/test_models/test_backbones/test_timm_backbone.py b/tests/test_models/test_backbones/test_timm_backbone.py
index 1ab06879c6b..4c6ae925dbe 100644
--- a/tests/test_models/test_backbones/test_timm_backbone.py
+++ b/tests/test_models/test_backbones/test_timm_backbone.py
@@ -17,10 +17,14 @@ def check_norm_state(modules, train_state):
 
 
 def test_timm_backbone():
+    """Test timm backbones, features_only=False (default)."""
     with pytest.raises(TypeError):
-        # pretrained must be a string path
-        model = TIMMBackbone()
-        model.init_weights(pretrained=0)
+        # TIMMBackbone has 1 required positional argument: 'model_name'
+        model = TIMMBackbone(pretrained=True)
+
+    with pytest.raises(TypeError):
+        # pretrained must be bool
+        model = TIMMBackbone(model_name='resnet18', pretrained='model.pth')
 
     # Test resnet18 from timm
     model = TIMMBackbone(model_name='resnet18')
@@ -57,3 +61,143 @@ def test_timm_backbone():
     feat = model(imgs)
     assert len(feat) == 1
     assert feat[0].shape == torch.Size((1, 192))
+
+
+def test_timm_backbone_features_only():
+    """Test timm backbones, features_only=True."""
+    # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
+    # Test resnet18 from timm, norm_layer='BN2d'
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32,
+        norm_layer='BN2d')
+
+    # Test resnet18 from timm, norm_layer='SyncBN'
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32,
+        norm_layer='SyncBN')
+
+    # Test resnet18 from timm, output_stride=32
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32)
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), True)
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 112, 112))
+    assert feats[1].shape == torch.Size((1, 64, 56, 56))
+    assert feats[2].shape == torch.Size((1, 128, 28, 28))
+    assert feats[3].shape == torch.Size((1, 256, 14, 14))
+    assert feats[4].shape == torch.Size((1, 512, 7, 7))
+
+    # Test resnet18 from timm, output_stride=32, out_indices=(1, 2, 3)
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32,
+        out_indices=(1, 2, 3))
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 3
+    assert feats[0].shape == torch.Size((1, 64, 56, 56))
+    assert feats[1].shape == torch.Size((1, 128, 28, 28))
+    assert feats[2].shape == torch.Size((1, 256, 14, 14))
+
+    # Test resnet18 from timm, output_stride=16
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=16)
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 112, 112))
+    assert feats[1].shape == torch.Size((1, 64, 56, 56))
+    assert feats[2].shape == torch.Size((1, 128, 28, 28))
+    assert feats[3].shape == torch.Size((1, 256, 14, 14))
+    assert feats[4].shape == torch.Size((1, 512, 14, 14))
+
+    # Test resnet18 from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 112, 112))
+    assert feats[1].shape == torch.Size((1, 64, 56, 56))
+    assert feats[2].shape == torch.Size((1, 128, 28, 28))
+    assert feats[3].shape == torch.Size((1, 256, 28, 28))
+    assert feats[4].shape == torch.Size((1, 512, 28, 28))
+
+    # Test efficientnet_b1 with pretrained weights
+    model = TIMMBackbone(
+        model_name='efficientnet_b1', features_only=True, pretrained=True)
+    imgs = torch.randn(1, 3, 64, 64)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 16, 32, 32))
+    assert feats[1].shape == torch.Size((1, 24, 16, 16))
+    assert feats[2].shape == torch.Size((1, 40, 8, 8))
+    assert feats[3].shape == torch.Size((1, 112, 4, 4))
+    assert feats[4].shape == torch.Size((1, 320, 2, 2))
+
+    # Test resnetv2_50x1_bitm from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnetv2_50x1_bitm',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 8, 8)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 4, 4))
+    assert feats[1].shape == torch.Size((1, 256, 2, 2))
+    assert feats[2].shape == torch.Size((1, 512, 1, 1))
+    assert feats[3].shape == torch.Size((1, 1024, 1, 1))
+    assert feats[4].shape == torch.Size((1, 2048, 1, 1))
+
+    # Test resnetv2_50x3_bitm from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnetv2_50x3_bitm',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 8, 8)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 192, 4, 4))
+    assert feats[1].shape == torch.Size((1, 768, 2, 2))
+    assert feats[2].shape == torch.Size((1, 1536, 1, 1))
+    assert feats[3].shape == torch.Size((1, 3072, 1, 1))
+    assert feats[4].shape == torch.Size((1, 6144, 1, 1))
+
+    # Test resnetv2_101x1_bitm from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnetv2_101x1_bitm',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 8, 8)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 4, 4))
+    assert feats[1].shape == torch.Size((1, 256, 2, 2))
+    assert feats[2].shape == torch.Size((1, 512, 1, 1))
+    assert feats[3].shape == torch.Size((1, 1024, 1, 1))
+    assert feats[4].shape == torch.Size((1, 2048, 1, 1))
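For reviewers who want to try the new interface, below is a minimal usage sketch. It is illustrative only and not part of this diff: it mirrors the resnet18 out_indices case exercised in test_timm_backbone_features_only, uses the same import path as the test file, and assumes mmcls and timm are installed.

# Illustrative usage sketch of the updated TIMMBackbone (not part of the diff).
# Mirrors the resnet18 out_indices case from test_timm_backbone_features_only.
import torch

from mmcls.models.backbones import TIMMBackbone

# features_only=True asks timm for a multi-scale feature pyramid; out_indices
# selects which pyramid levels forward() should return.
backbone = TIMMBackbone(
    model_name='resnet18',
    features_only=True,
    pretrained=False,
    output_stride=32,
    out_indices=(1, 2, 3))
backbone.init_weights()
backbone.eval()

imgs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # forward() always returns a tuple of tensors, one per selected level.
    feats = backbone(imgs)

# Per the new tests, the expected shapes here are:
# (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14)
for feat in feats:
    print(feat.shape)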