Commit

[Feature] Enable torchvision backbones (open-mmlab#720)
* resolve comments

* update changelog

* support torchvision backbones

* add ckpt, changelog and unittest

* fix lint

* fix lint

* fix lint

* Update changelog.md

* Update changelog.md
kennymckormick authored Mar 25, 2021
1 parent 8a79e52 commit dea32a5
Showing 6 changed files with 182 additions and 8 deletions.
8 changes: 5 additions & 3 deletions configs/recognition/tsn/README.md
@@ -58,11 +58,13 @@ Here, We use [1: 1] to indicate that we combine rgb and flow score with coeffici

It is possible and convenient to use a 3rd-party backbone for TSN under the MMAction2 framework. Here we provide examples for (a minimal configuration sketch follows the table below):

- [x] Backbones from [MMClassification](https://github.com/open-mmlab/mmclassification/)
- [x] Backbones from [TorchVision](https://github.com/pytorch/vision/)

| config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json |
| :----------------------------------------------------------- | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) |
| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Densenet161 [[TorchVision](https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) |
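As the table shows, switching TSN to a third-party backbone only requires prefixing the backbone type with its source library and matching the head's `in_channels` to the backbone's output width. A minimal sketch (head options trimmed; the MMCls constructor kwargs are illustrative and depend on the backbone you pick):

```python
# TorchVision backbone: prefix the torchvision.models name with 'torchvision.'
model = dict(
    type='Recognizer2D',
    backbone=dict(type='torchvision.densenet161', pretrained=True),
    cls_head=dict(type='TSNHead', num_classes=400, in_channels=2208))

# MMClassification backbone: prefix the MMCls class name with 'mmcls.'
model = dict(
    type='Recognizer2D',
    backbone=dict(type='mmcls.ResNeXt', depth=101, groups=32, width_per_group=4),
    cls_head=dict(type='TSNHead', num_classes=400, in_channels=2048))
```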

### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments)

100 changes: 100 additions & 0 deletions configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py
@@ -0,0 +1,100 @@
_base_ = [
    '../../../_base_/schedules/sgd_100e.py',
    '../../../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(type='torchvision.densenet161', pretrained=True),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=2208,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01),
    # model training and testing settings
    train_cfg=None,
    test_cfg=dict(average_clips=None))

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train_320p'
data_root_val = 'data/kinetics400/rawframes_val_320p'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes_320p.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=3,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
data = dict(
    videos_per_gpu=12,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline))

# runtime settings
work_dir = './work_dirs/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/'
optimizer = dict(
    type='SGD',
    lr=0.00375,  # this lr is used for 8 gpus
    momentum=0.9,
    weight_decay=0.0001)
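A note on the learning rate above: MMAction2's TSN configs follow the linear scaling rule, and the reference TSN recipe (an assumption here, not part of this commit) uses lr=0.01 at a total batch size of 256, i.e. 8 GPUs × 32 videos. Scaling to this config's 8 GPUs × 12 videos reproduces the 0.00375:

```python
# Linear scaling rule sanity check: lr scales with the total batch size.
base_lr, base_batch_size = 0.01, 256  # assumed TSN reference: 8 GPUs x 32 videos
gpus, videos_per_gpu = 8, 12          # this config ('this lr is used for 8 gpus')
print(base_lr * gpus * videos_per_gpu / base_batch_size)  # 0.00375
```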
4 changes: 3 additions & 1 deletion docs/changelog.md
@@ -8,6 +8,7 @@

- Support LFB ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
- Support using backbones from MMCls for TSN ([#679](https://github.com/open-mmlab/mmaction2/pull/679))
- Support using backbones from TorchVision for TSN ([#720](https://github.com/open-mmlab/mmaction2/pull/720))
- Support Mixup and Cutmix for recognizers ([#681](https://github.com/open-mmlab/mmaction2/pull/681))

**Improvements**
@@ -23,7 +24,8 @@

- Add LFB for AVA2.1 ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
- Add slowonly_nl_embedded_gaussian_r50_4x16x1_150e_kinetics400_rgb ([#690](https://github.com/open-mmlab/mmaction2/pull/690))
- Add TSN with ResNeXt-101-32x4d backbone as an example for using MMCls backbones ([#679](https://github.com/open-mmlab/mmaction2/pull/679))
- Add TSN with Densenet161 backbone as an example for using TorchVision backbones ([#720](https://github.com/open-mmlab/mmaction2/pull/720))
- Add slowonly_nl_embedded_gaussian_r50_8x8x1_150e_kinetics400_rgb ([#704](https://github.com/open-mmlab/mmaction2/pull/704))
- Add slowonly_nl_kinetics_pretrained_r50_4x16x1(8x8x1)_20e_ava_rgb ([#730](https://github.com/open-mmlab/mmaction2/pull/730))

37 changes: 34 additions & 3 deletions mmaction/models/recognizers/base.py
@@ -1,3 +1,4 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

@@ -33,14 +34,30 @@ def __init__(self,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        # record the source of the backbone
        self.backbone_from = 'mmaction2'

        if backbone['type'].startswith('mmcls.'):
            try:
                import mmcls.models.builder as mmcls_builder
            except (ImportError, ModuleNotFoundError):
                raise ImportError('Please install mmcls to use this backbone.')
            backbone['type'] = backbone['type'][6:]
            self.backbone = mmcls_builder.build_backbone(backbone)
            self.backbone_from = 'mmcls'
        elif backbone['type'].startswith('torchvision.'):
            try:
                import torchvision.models
            except (ImportError, ModuleNotFoundError):
                raise ImportError('Please install torchvision to use this '
                                  'backbone.')
            backbone_type = backbone.pop('type')[12:]
            self.backbone = torchvision.models.__dict__[backbone_type](
                **backbone)
            # disable the classifier
            self.backbone.classifier = nn.Identity()
            self.backbone.fc = nn.Identity()
            self.backbone_from = 'torchvision'
        else:
            self.backbone = builder.build_backbone(backbone)
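Both `classifier` and `fc` are overwritten because TorchVision families name their classification layer differently (ResNets expose `.fc`, DenseNets `.classifier`); assigning `nn.Identity()` to a name a model does not use merely registers an unused submodule. A small stand-alone illustration (not part of the commit):

```python
import torchvision
from torch import nn

# ResNets classify through `.fc`, DenseNets through `.classifier`; writing
# nn.Identity() to both disables the real head and harmlessly adds an
# unused module under the other name.
for model in (torchvision.models.resnet18(), torchvision.models.densenet121()):
    model.classifier = nn.Identity()
    model.fc = nn.Identity()
    print(type(model).__name__, 'classifier disabled')
```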

@@ -76,7 +93,17 @@ def __init__(self,

    def init_weights(self):
        """Initialize the model network weights."""
        if self.backbone_from in ['mmcls', 'mmaction2']:
            self.backbone.init_weights()
        elif self.backbone_from == 'torchvision':
            warnings.warn('We do not initialize weights for backbones in '
                          'torchvision, since the weights for backbones in '
                          'torchvision are initialized in their __init__ '
                          'functions.')
        else:
            raise NotImplementedError('Unsupported backbone source '
                                      f'{self.backbone_from}!')

        self.cls_head.init_weights()
        if hasattr(self, 'neck'):
            self.neck.init_weights()
@@ -91,7 +118,11 @@ def extract_feat(self, imgs):
        Returns:
            torch.tensor: The extracted features.
        """
        if (hasattr(self.backbone, 'features')
                and self.backbone_from == 'torchvision'):
            x = self.backbone.features(imgs)
        else:
            x = self.backbone(imgs)
        return x

    def average_clip(self, cls_score, num_segs=1):
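The `.features` branch matters because a TorchVision model's `forward` pools and flattens before (what was) the classifier, so the spatial layout is lost even with the head disabled; `.features`, where a model provides it, preserves the feature map. A quick shape check (not part of the commit):

```python
import torch
import torchvision
from torch import nn

model = torchvision.models.densenet161()
model.classifier = nn.Identity()  # as done in __init__ above
x = torch.randn(1, 3, 224, 224)

# forward() still global-pools and flattens; features() keeps the map.
print(model(x).shape)           # torch.Size([1, 2208])
print(model.features(x).shape)  # torch.Size([1, 2208, 7, 7])
```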
19 changes: 18 additions & 1 deletion mmaction/models/recognizers/recognizer2d.py
@@ -1,4 +1,5 @@
import torch
from torch import nn

from ..registry import RECOGNIZERS
from .base import BaseRecognizer
@@ -17,6 +18,14 @@ def forward_train(self, imgs, labels, **kwargs):
        losses = dict()

        x = self.extract_feat(imgs)

        if self.backbone_from == 'torchvision':
            if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                # apply adaptive avg pooling
                x = nn.AdaptiveAvgPool2d(1)(x)
            x = x.reshape((x.shape[0], -1))
            x = x.reshape(x.shape + (1, 1))

        if hasattr(self, 'neck'):
            x = [
                each.reshape((-1, num_segs) +
@@ -43,6 +52,14 @@ def _do_test(self, imgs):
        num_segs = imgs.shape[0] // batches

        x = self.extract_feat(imgs)

        if self.backbone_from == 'torchvision':
            if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                # apply adaptive avg pooling
                x = nn.AdaptiveAvgPool2d(1)(x)
            x = x.reshape((x.shape[0], -1))
            x = x.reshape(x.shape + (1, 1))

        if hasattr(self, 'neck'):
            x = [
                each.reshape((-1, num_segs) +
@@ -110,7 +127,7 @@ def forward_test(self, imgs):
"""Defines the computation performed at every call when evaluation and
testing."""
if self.test_cfg.get('fcn_test', False):
# If specified, spatially fully-convolutional testing is performed
# If specified, spatially fully-convolutional testing is performed
return self._do_fcn_test(imgs).cpu().numpy()
return self._do_test(imgs).cpu().numpy()
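For readers tracing shapes through the new pooling branch: the backbone output is pooled to 1×1, flattened, and then given its trailing singleton dimensions back so downstream code keeps seeing NCHW. A worked example using densenet161's dimensions (illustrative values):

```python
import torch
from torch import nn

x = torch.randn(3, 2208, 7, 7)       # (batch * num_segs, C, H, W) from the backbone
x = nn.AdaptiveAvgPool2d(1)(x)       # -> (3, 2208, 1, 1)
x = x.reshape((x.shape[0], -1))      # -> (3, 2208)
x = x.reshape(x.shape + (1, 1))      # -> (3, 2208, 1, 1): NCHW again for the head
print(x.shape)
```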

22 changes: 22 additions & 0 deletions tests/test_models/test_recognizers/test_recognizer2d.py
@@ -69,6 +69,28 @@ def test_tsn():
    losses = recognizer(imgs, gt_labels)
    assert isinstance(losses, dict)

    # test torchvision backbones
    tv_backbone = dict(type='torchvision.densenet161', pretrained=True)
    config.model['backbone'] = tv_backbone
    config.model['cls_head']['in_channels'] = 2208

    recognizer = build_recognizer(config.model)

    input_shape = (1, 3, 3, 32, 32)
    demo_inputs = generate_recognizer_demo_inputs(input_shape)

    imgs = demo_inputs['imgs']
    gt_labels = demo_inputs['gt_labels']

    losses = recognizer(imgs, gt_labels)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        img_list = [img[None, :] for img in imgs]
        for one_img in img_list:
            recognizer(one_img, None, return_loss=False)


def test_tsm():
    config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')
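A closing note on the test above: the demo `input_shape` of (1, 3, 3, 32, 32) reads as (batch, num_segments, C, H, W), and Recognizer2D folds the segment axis into the batch before the backbone runs, which is what the `num_segs = imgs.shape[0] // batches` line earlier in the diff recovers. A sketch of that reshape (mirroring, not invoking, the recognizer internals):

```python
import torch

imgs = torch.randn(1, 3, 3, 32, 32)          # (batch, num_segments, C, H, W)
batches = imgs.shape[0]
imgs = imgs.reshape((-1,) + imgs.shape[2:])  # -> (3, 3, 32, 32)
num_segs = imgs.shape[0] // batches          # -> 3, as in _do_test above
print(imgs.shape, num_segs)
```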
