-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support TIMMBackbone #998
Changes from all commits
064e1f9
ec8ffae
57e3983
3fa7650
2a8503c
eeca97a
831f6ef
e4b2cb0
f25eea5
b75c201
bc7e80e
c4cbe44
a09d948
68bf65f
5eced88
500f22e
0db1bd1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,9 +71,17 @@ jobs: | |
run: rm -rf .eggs && pip install -e . | ||
- name: Run unittests and generate coverage report | ||
run: | | ||
pip install timm | ||
coverage run --branch --source mmseg -m pytest tests/ | ||
coverage xml | ||
coverage report -m | ||
if: ${{matrix.torch >= '1.5.0'}} | ||
- name: Skip timm unittests and generate coverage report | ||
run: | | ||
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we could undo the ignore There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TIMM does not support pt1.3 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
coverage xml | ||
coverage report -m | ||
if: ${{matrix.torch < '1.5.0'}} | ||
|
||
build_cuda101: | ||
runs-on: ubuntu-18.04 | ||
|
@@ -142,9 +150,17 @@ jobs: | |
TORCH_CUDA_ARCH_LIST=7.0 pip install . | ||
- name: Run unittests and generate coverage report | ||
run: | | ||
python -m pip install timm | ||
coverage run --branch --source mmseg -m pytest tests/ | ||
coverage xml | ||
coverage report -m | ||
if: ${{matrix.torch >= '1.5.0'}} | ||
- name: Skip timm unittests and generate coverage report | ||
run: | | ||
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py | ||
coverage xml | ||
coverage report -m | ||
if: ${{matrix.torch < '1.5.0'}} | ||
- name: Upload coverage to Codecov | ||
uses: codecov/codecov-action@v1.0.10 | ||
with: | ||
|
@@ -198,6 +214,7 @@ jobs: | |
TORCH_CUDA_ARCH_LIST=7.0 pip install . | ||
- name: Run unittests and generate coverage report | ||
run: | | ||
python -m pip install timm | ||
coverage run --branch --source mmseg -m pytest tests/ | ||
coverage xml | ||
coverage report -m | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
try: | ||
import timm | ||
except ImportError: | ||
timm = None | ||
|
||
from mmcv.cnn.bricks.registry import NORM_LAYERS | ||
from mmcv.runner import BaseModule | ||
|
||
from ..builder import BACKBONES | ||
|
||
|
||
@BACKBONES.register_module() | ||
class TIMMBackbone(BaseModule): | ||
"""Wrapper to use backbones from timm library. More details can be found in | ||
`timm <https://github.com/rwightman/pytorch-image-models>`_ . | ||
|
||
Args: | ||
model_name (str): Name of timm model to instantiate. | ||
pretrained (bool): Load pretrained weights if True. | ||
checkpoint_path (str): Path of checkpoint to load after | ||
model is initialized. | ||
in_channels (int): Number of input image channels. Default: 3. | ||
init_cfg (dict, optional): Initialization config dict | ||
**kwargs: Other timm & model specific arguments. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_name, | ||
features_only=True, | ||
pretrained=True, | ||
checkpoint_path='', | ||
in_channels=3, | ||
init_cfg=None, | ||
**kwargs, | ||
): | ||
if timm is None: | ||
raise RuntimeError('timm is not installed') | ||
super(TIMMBackbone, self).__init__(init_cfg) | ||
if 'norm_layer' in kwargs: | ||
kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) | ||
self.timm_model = timm.create_model( | ||
model_name=model_name, | ||
features_only=features_only, | ||
pretrained=pretrained, | ||
in_chans=in_channels, | ||
checkpoint_path=checkpoint_path, | ||
**kwargs, | ||
) | ||
|
||
# Make unused parameters None | ||
self.timm_model.global_pool = None | ||
self.timm_model.fc = None | ||
self.timm_model.classifier = None | ||
|
||
# Hack to use pretrained weights from timm | ||
if pretrained or checkpoint_path: | ||
self._is_init = True | ||
|
||
def forward(self, x): | ||
features = self.timm_model(x) | ||
return features |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import pytest | ||
import torch | ||
|
||
from mmseg.models.backbones import TIMMBackbone | ||
from .utils import check_norm_state | ||
|
||
|
||
def test_timm_backbone(): | ||
with pytest.raises(TypeError): | ||
# pretrained must be a string path | ||
model = TIMMBackbone() | ||
model.init_weights(pretrained=0) | ||
|
||
# Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN' | ||
# Test resnet18 from timm, norm_layer='BN2d' | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=32, | ||
norm_layer='BN2d') | ||
|
||
# Test resnet18 from timm, norm_layer='SyncBN' | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=32, | ||
norm_layer='SyncBN') | ||
|
||
# Test resnet18 from timm, features_only=True, output_stride=32 | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=32) | ||
model.init_weights() | ||
model.train() | ||
assert check_norm_state(model.modules(), True) | ||
|
||
imgs = torch.randn(1, 3, 224, 224) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 112, 112)) | ||
assert feats[1] == torch.Size((1, 64, 56, 56)) | ||
assert feats[2] == torch.Size((1, 128, 28, 28)) | ||
assert feats[3] == torch.Size((1, 256, 14, 14)) | ||
assert feats[4] == torch.Size((1, 512, 7, 7)) | ||
|
||
# Test resnet18 from timm, features_only=True, output_stride=16 | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=16) | ||
imgs = torch.randn(1, 3, 224, 224) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 112, 112)) | ||
assert feats[1] == torch.Size((1, 64, 56, 56)) | ||
assert feats[2] == torch.Size((1, 128, 28, 28)) | ||
assert feats[3] == torch.Size((1, 256, 14, 14)) | ||
assert feats[4] == torch.Size((1, 512, 14, 14)) | ||
|
||
# Test resnet18 from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 224, 224) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 112, 112)) | ||
assert feats[1] == torch.Size((1, 64, 56, 56)) | ||
assert feats[2] == torch.Size((1, 128, 28, 28)) | ||
assert feats[3] == torch.Size((1, 256, 28, 28)) | ||
assert feats[4] == torch.Size((1, 512, 28, 28)) | ||
|
||
# Test efficientnet_b1 with pretrained weights | ||
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True) | ||
|
||
# Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnetv2_50x1_bitm', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 8, 8) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 4, 4)) | ||
assert feats[1] == torch.Size((1, 256, 2, 2)) | ||
assert feats[2] == torch.Size((1, 512, 1, 1)) | ||
assert feats[3] == torch.Size((1, 1024, 1, 1)) | ||
assert feats[4] == torch.Size((1, 2048, 1, 1)) | ||
|
||
# Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnetv2_50x3_bitm', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 8, 8) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 192, 4, 4)) | ||
assert feats[1] == torch.Size((1, 768, 2, 2)) | ||
assert feats[2] == torch.Size((1, 1536, 1, 1)) | ||
assert feats[3] == torch.Size((1, 3072, 1, 1)) | ||
assert feats[4] == torch.Size((1, 6144, 1, 1)) | ||
|
||
# Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnetv2_101x1_bitm', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 8, 8) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 4, 4)) | ||
assert feats[1] == torch.Size((1, 256, 2, 2)) | ||
assert feats[2] == torch.Size((1, 512, 1, 1)) | ||
assert feats[3] == torch.Size((1, 1024, 1, 1)) | ||
assert feats[4] == torch.Size((1, 2048, 1, 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could CI work with TIMM?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/open-mmlab/mmsegmentation/runs/4016108790?check_suite_focus=true
Can work with PyTorch >= 1.6.0, but failed with PyTorch 1.5.1
Please take a look at these two commits:
b75c201
bc7e80e