diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index dbb08b48ed..f41e4176cf 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -71,9 +71,17 @@ jobs: run: rm -rf .eggs && pip install -e . - name: Run unittests and generate coverage report run: | + pip install timm coverage run --branch --source mmseg -m pytest tests/ coverage xml coverage report -m + if: ${{matrix.torch >= '1.5.0'}} + - name: Skip timm unittests and generate coverage report + run: | + coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py + coverage xml + coverage report -m + if: ${{matrix.torch < '1.5.0'}} build_cuda101: runs-on: ubuntu-18.04 @@ -142,9 +150,17 @@ jobs: TORCH_CUDA_ARCH_LIST=7.0 pip install . - name: Run unittests and generate coverage report run: | + python -m pip install timm coverage run --branch --source mmseg -m pytest tests/ coverage xml coverage report -m + if: ${{matrix.torch >= '1.5.0'}} + - name: Skip timm unittests and generate coverage report + run: | + coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py + coverage xml + coverage report -m + if: ${{matrix.torch < '1.5.0'}} - name: Upload coverage to Codecov uses: codecov/codecov-action@v1.0.10 with: @@ -198,6 +214,7 @@ jobs: TORCH_CUDA_ARCH_LIST=7.0 pip install . - name: Run unittests and generate coverage report run: | + python -m pip install timm coverage run --branch --source mmseg -m pytest tests/ coverage xml coverage report -m diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index 6d320323b8..408d3981dd 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -12,6 +12,7 @@ from .resnet import ResNet, ResNetV1c, ResNetV1d from .resnext import ResNeXt from .swin import SwinTransformer +from .timm_backbone import TIMMBackbone from .unet import UNet from .vit import VisionTransformer @@ -19,5 +20,5 @@ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', - 'BiSeNetV1', 'BiSeNetV2', 'ICNet' + 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone' ] diff --git a/mmseg/models/backbones/timm_backbone.py b/mmseg/models/backbones/timm_backbone.py new file mode 100644 index 0000000000..01b29fc5ed --- /dev/null +++ b/mmseg/models/backbones/timm_backbone.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + import timm +except ImportError: + timm = None + +from mmcv.cnn.bricks.registry import NORM_LAYERS +from mmcv.runner import BaseModule + +from ..builder import BACKBONES + + +@BACKBONES.register_module() +class TIMMBackbone(BaseModule): + """Wrapper to use backbones from timm library. More details can be found in + `timm `_ . + + Args: + model_name (str): Name of timm model to instantiate. + pretrained (bool): Load pretrained weights if True. + checkpoint_path (str): Path of checkpoint to load after + model is initialized. + in_channels (int): Number of input image channels. Default: 3. + init_cfg (dict, optional): Initialization config dict + **kwargs: Other timm & model specific arguments. + """ + + def __init__( + self, + model_name, + features_only=True, + pretrained=True, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs, + ): + if timm is None: + raise RuntimeError('timm is not installed') + super(TIMMBackbone, self).__init__(init_cfg) + if 'norm_layer' in kwargs: + kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs, + ) + + # Make unused parameters None + self.timm_model.global_pool = None + self.timm_model.fc = None + self.timm_model.classifier = None + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + def forward(self, x): + features = self.timm_model(x) + return features diff --git a/tests/test_models/test_backbones/test_timm_backbone.py b/tests/test_models/test_backbones/test_timm_backbone.py new file mode 100644 index 0000000000..85ef9aa56f --- /dev/null +++ b/tests/test_models/test_backbones/test_timm_backbone.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmseg.models.backbones import TIMMBackbone +from .utils import check_norm_state + + +def test_timm_backbone(): + with pytest.raises(TypeError): + # pretrained must be a string path + model = TIMMBackbone() + model.init_weights(pretrained=0) + + # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN' + # Test resnet18 from timm, norm_layer='BN2d' + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=32, + norm_layer='BN2d') + + # Test resnet18 from timm, norm_layer='SyncBN' + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=32, + norm_layer='SyncBN') + + # Test resnet18 from timm, features_only=True, output_stride=32 + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=32) + model.init_weights() + model.train() + assert check_norm_state(model.modules(), True) + + imgs = torch.randn(1, 3, 224, 224) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 112, 112)) + assert feats[1] == torch.Size((1, 64, 56, 56)) + assert feats[2] == torch.Size((1, 128, 28, 28)) + assert feats[3] == torch.Size((1, 256, 14, 14)) + assert feats[4] == torch.Size((1, 512, 7, 7)) + + # Test resnet18 from timm, features_only=True, output_stride=16 + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=16) + imgs = torch.randn(1, 3, 224, 224) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 112, 112)) + assert feats[1] == torch.Size((1, 64, 56, 56)) + assert feats[2] == torch.Size((1, 128, 28, 28)) + assert feats[3] == torch.Size((1, 256, 14, 14)) + assert feats[4] == torch.Size((1, 512, 14, 14)) + + # Test resnet18 from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 224, 224) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 112, 112)) + assert feats[1] == torch.Size((1, 64, 56, 56)) + assert feats[2] == torch.Size((1, 128, 28, 28)) + assert feats[3] == torch.Size((1, 256, 28, 28)) + assert feats[4] == torch.Size((1, 512, 28, 28)) + + # Test efficientnet_b1 with pretrained weights + model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True) + + # Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnetv2_50x1_bitm', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 8, 8) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 4, 4)) + assert feats[1] == torch.Size((1, 256, 2, 2)) + assert feats[2] == torch.Size((1, 512, 1, 1)) + assert feats[3] == torch.Size((1, 1024, 1, 1)) + assert feats[4] == torch.Size((1, 2048, 1, 1)) + + # Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnetv2_50x3_bitm', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 8, 8) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 192, 4, 4)) + assert feats[1] == torch.Size((1, 768, 2, 2)) + assert feats[2] == torch.Size((1, 1536, 1, 1)) + assert feats[3] == torch.Size((1, 3072, 1, 1)) + assert feats[4] == torch.Size((1, 6144, 1, 1)) + + # Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnetv2_101x1_bitm', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 8, 8) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 4, 4)) + assert feats[1] == torch.Size((1, 256, 2, 2)) + assert feats[2] == torch.Size((1, 512, 1, 1)) + assert feats[3] == torch.Size((1, 1024, 1, 1)) + assert feats[4] == torch.Size((1, 2048, 1, 1))