Support features_only in TIMMBackbone
shinya7y committed Jan 23, 2022
1 parent e694269 commit b0a38ad
Showing 2 changed files with 226 additions and 25 deletions.
101 changes: 79 additions & 22 deletions mmcls/models/backbones/timm_backbone.py
@@ -4,52 +4,109 @@
 except ImportError:
     timm = None
 
+import warnings
+
+from mmcv.cnn.bricks.registry import NORM_LAYERS
+
+from ...utils import get_root_logger
 from ..builder import BACKBONES
 from .base_backbone import BaseBackbone
 
 
+def print_timm_feature_info(feature_info):
+    """Print feature_info of timm backbone to help development and debug.
+
+    Args:
+        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
+            feature_info of timm backbone.
+    """
+    logger = get_root_logger()
+    if feature_info is None:
+        logger.warning('This backbone does not have feature_info')
+    elif isinstance(feature_info, list):
+        for feat_idx, each_info in enumerate(feature_info):
+            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
+    else:
+        try:
+            logger.info(f'backbone out_indices: {feature_info.out_indices}')
+            logger.info(f'backbone out_channels: {feature_info.channels()}')
+            logger.info(f'backbone out_strides: {feature_info.reduction()}')
+        except AttributeError:
+            logger.warning('Unexpected format of backbone feature_info')
+
+
 @BACKBONES.register_module()
 class TIMMBackbone(BaseBackbone):
-    """Wrapper to use backbones from timm library. More details can be found in
-    `timm <https://github.com/rwightman/pytorch-image-models>`_ .
+    """Wrapper to use backbones from timm library.
+
+    More details can be found in
+    `timm <https://github.com/rwightman/pytorch-image-models>`_.
+    See especially the document for `feature extraction
+    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.
 
     Args:
         model_name (str): Name of timm model to instantiate.
-        pretrained (bool): Load pretrained weights if True.
-        checkpoint_path (str): Path of checkpoint to load after
-            model is initialized.
-        in_channels (int): Number of input image channels. Default: 3.
-        init_cfg (dict, optional): Initialization config dict
+        features_only (bool, optional): Whether to extract feature pyramid
+            (multi-scale feature maps from the deepest layer at each stride).
+            For Vision Transformer models that do not support this argument,
+            set this False. Default: False.
+        pretrained (bool, optional): Whether to load pretrained weights.
+            Default: False.
+        checkpoint_path (str, optional): Path of checkpoint to load at the
+            last of timm.create_model. Default: '', which means not loading.
+        in_channels (int, optional): Number of input image channels.
+            Default: 3.
+        init_cfg (dict or list[dict], optional): Initialization config dict of
+            OpenMMLab projects. Default: None.
         **kwargs: Other timm & model specific arguments.
     """
 
-    def __init__(
-        self,
-        model_name,
-        pretrained=False,
-        checkpoint_path='',
-        in_channels=3,
-        init_cfg=None,
-        **kwargs,
-    ):
+    def __init__(self,
+                 model_name,
+                 features_only=False,
+                 pretrained=False,
+                 checkpoint_path='',
+                 in_channels=3,
+                 init_cfg=None,
+                 **kwargs):
         if timm is None:
-            raise RuntimeError('timm is not installed')
+            raise RuntimeError(
+                'Failed to import timm. Please run "pip install timm". '
+                '"pip install dataclasses" may also be needed for Python 3.6.')
+        if not isinstance(pretrained, bool):
+            raise TypeError('pretrained must be bool, not str for model path')
+        if features_only and checkpoint_path:
+            warnings.warn(
+                'Using both features_only and checkpoint_path will cause error'
+                ' in timm. See '
+                'https://github.com/rwightman/pytorch-image-models/issues/488')
+
         super(TIMMBackbone, self).__init__(init_cfg)
+        if 'norm_layer' in kwargs:
+            kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
         self.timm_model = timm.create_model(
             model_name=model_name,
+            features_only=features_only,
             pretrained=pretrained,
             in_chans=in_channels,
             checkpoint_path=checkpoint_path,
-            **kwargs,
-        )
+            **kwargs)
 
         # reset classifier
-        self.timm_model.reset_classifier(0, '')
+        if hasattr(self.timm_model, 'reset_classifier'):
+            self.timm_model.reset_classifier(0, '')
 
         # Hack to use pretrained weights from timm
         if pretrained or checkpoint_path:
             self._is_init = True
 
+        feature_info = getattr(self.timm_model, 'feature_info', None)
+        print_timm_feature_info(feature_info)
+
     def forward(self, x):
-        features = self.timm_model.forward_features(x)
-        return (features, )
+        features = self.timm_model(x)
+        if isinstance(features, (list, tuple)):
+            features = tuple(features)
+        else:
+            features = (features, )
+        return features
150 changes: 147 additions & 3 deletions tests/test_models/test_backbones/test_timm_backbone.py
@@ -17,10 +17,14 @@ def check_norm_state(modules, train_state):
 
 
 def test_timm_backbone():
+    """Test timm backbones, features_only=False (default)."""
     with pytest.raises(TypeError):
-        # pretrained must be a string path
-        model = TIMMBackbone()
-        model.init_weights(pretrained=0)
+        # TIMMBackbone has 1 required positional argument: 'model_name'
+        model = TIMMBackbone(pretrained=True)
+
+    with pytest.raises(TypeError):
+        # pretrained must be bool
+        model = TIMMBackbone(model_name='resnet18', pretrained='model.pth')
 
     # Test resnet18 from timm
     model = TIMMBackbone(model_name='resnet18')
@@ -57,3 +61,143 @@ def test_timm_backbone():
     feat = model(imgs)
     assert len(feat) == 1
     assert feat[0].shape == torch.Size((1, 192))
+
+
+def test_timm_backbone_features_only():
+    """Test timm backbones, features_only=True."""
+    # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
+    # Test resnet18 from timm, norm_layer='BN2d'
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32,
+        norm_layer='BN2d')
+
+    # Test resnet18 from timm, norm_layer='SyncBN'
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32,
+        norm_layer='SyncBN')
+
+    # Test resnet18 from timm, output_stride=32
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32)
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), True)
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 112, 112))
+    assert feats[1].shape == torch.Size((1, 64, 56, 56))
+    assert feats[2].shape == torch.Size((1, 128, 28, 28))
+    assert feats[3].shape == torch.Size((1, 256, 14, 14))
+    assert feats[4].shape == torch.Size((1, 512, 7, 7))
+
+    # Test resnet18 from timm, output_stride=32, out_indices=(1, 2, 3)
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=32,
+        out_indices=(1, 2, 3))
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 3
+    assert feats[0].shape == torch.Size((1, 64, 56, 56))
+    assert feats[1].shape == torch.Size((1, 128, 28, 28))
+    assert feats[2].shape == torch.Size((1, 256, 14, 14))
+
+    # Test resnet18 from timm, output_stride=16
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=16)
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 112, 112))
+    assert feats[1].shape == torch.Size((1, 64, 56, 56))
+    assert feats[2].shape == torch.Size((1, 128, 28, 28))
+    assert feats[3].shape == torch.Size((1, 256, 14, 14))
+    assert feats[4].shape == torch.Size((1, 512, 14, 14))
+
+    # Test resnet18 from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnet18',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 224, 224)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 112, 112))
+    assert feats[1].shape == torch.Size((1, 64, 56, 56))
+    assert feats[2].shape == torch.Size((1, 128, 28, 28))
+    assert feats[3].shape == torch.Size((1, 256, 28, 28))
+    assert feats[4].shape == torch.Size((1, 512, 28, 28))
+
+    # Test efficientnet_b1 with pretrained weights
+    model = TIMMBackbone(
+        model_name='efficientnet_b1', features_only=True, pretrained=True)
+    imgs = torch.randn(1, 3, 64, 64)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 16, 32, 32))
+    assert feats[1].shape == torch.Size((1, 24, 16, 16))
+    assert feats[2].shape == torch.Size((1, 40, 8, 8))
+    assert feats[3].shape == torch.Size((1, 112, 4, 4))
+    assert feats[4].shape == torch.Size((1, 320, 2, 2))
+
+    # Test resnetv2_50x1_bitm from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnetv2_50x1_bitm',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 8, 8)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 4, 4))
+    assert feats[1].shape == torch.Size((1, 256, 2, 2))
+    assert feats[2].shape == torch.Size((1, 512, 1, 1))
+    assert feats[3].shape == torch.Size((1, 1024, 1, 1))
+    assert feats[4].shape == torch.Size((1, 2048, 1, 1))
+
+    # Test resnetv2_50x3_bitm from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnetv2_50x3_bitm',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 8, 8)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 192, 4, 4))
+    assert feats[1].shape == torch.Size((1, 768, 2, 2))
+    assert feats[2].shape == torch.Size((1, 1536, 1, 1))
+    assert feats[3].shape == torch.Size((1, 3072, 1, 1))
+    assert feats[4].shape == torch.Size((1, 6144, 1, 1))
+
+    # Test resnetv2_101x1_bitm from timm, output_stride=8
+    model = TIMMBackbone(
+        model_name='resnetv2_101x1_bitm',
+        features_only=True,
+        pretrained=False,
+        output_stride=8)
+    imgs = torch.randn(1, 3, 8, 8)
+    feats = model(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size((1, 64, 4, 4))
+    assert feats[1].shape == torch.Size((1, 256, 2, 2))
+    assert feats[2].shape == torch.Size((1, 512, 1, 1))
+    assert feats[3].shape == torch.Size((1, 1024, 1, 1))
+    assert feats[4].shape == torch.Size((1, 2048, 1, 1))
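For reference, a standalone sketch (not part of this commit, assumes timm is installed) that reproduces the channel and stride numbers asserted in the out_indices=(1, 2, 3) case above by querying timm's feature_info directly, the same metadata that print_timm_feature_info logs.

# Standalone sketch: inspect the feature pyramid metadata of a timm model
# created with features_only=True, as the wrapped backbone does internally.
import timm
import torch

model = timm.create_model(
    'resnet18', features_only=True, pretrained=False, out_indices=(1, 2, 3))
print(model.feature_info.channels())   # [64, 128, 256]
print(model.feature_info.reduction())  # [4, 8, 16]

feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14)]

With a 224x224 input, reductions of 4, 8, and 16 give the 56, 28, and 14 spatial sizes checked by test_timm_backbone_features_only.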
