Commit 383c531

[Feature] Support TAM (open-mmlab#595)
* draft of tam block
* draft of TANet
* fix linting
* add unittest for tam
* add TAM to models/__init__.py
* add config file and unittest for tanet
* fix tam bug
* update tam
* fix __init__
* rename config
* fix tsm_optimizer bug
* add sth
* fix tanet config
* modify workers
* update tanet config
* fix typo
* add recognizer2d unittest
* update according to comments
* add readme of tanet
* add url
1 parent: c35390c

17 files changed: +548 -10

configs/_base_/models/tanet_r50.py (+19)

```python
# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='TANet',
        pretrained='torchvision://resnet50',
        depth=50,
        num_segments=8,
        tam_cfg=dict()),
    cls_head=dict(
        type='TSMHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.5,
        init_std=0.001))
train_cfg = None
test_cfg = dict(average_clips='prob')
```
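
As a sanity check, this config can be instantiated through mmaction2's builder. A minimal sketch, assuming mmcv and this commit of mmaction2 are installed (`build_model` is the builder exported by `mmaction/models/__init__.py` below):

```python
from mmcv import Config

from mmaction.models import build_model

# Load the base model config and build the Recognizer2D with a TANet backbone.
cfg = Config.fromfile('configs/_base_/models/tanet_r50.py')
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(type(model).__name__)  # Recognizer2D
```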

configs/recognition/tanet/README.md (+67)
# TANet

## Introduction

[ALGORITHM]

```latex
@article{liu2020tam,
  title={TAM: Temporal Adaptive Module for Video Recognition},
  author={Liu, Zhaoyang and Wang, Limin and Wu, Wayne and Qian, Chen and Lu, Tong},
  journal={arXiv preprint arXiv:2005.06803},
  year={2020}
}
```

## Model Zoo

### Kinetics-400

|config | resolution | gpus | backbone | pretrain | top1 acc| top5 acc | reference top1 acc | reference top5 acc | inference_time(video/s) | gpu_mem(M)| ckpt | log| json|
|:--|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|[tanet_r50_dense_1x1x8_100e_kinetics400_rgb](/configs/recognition/tanet/tanet_r50_dense_1x1x8_100e_kinetics400_rgb.py) |short-side 320|8| TANet | ImageNet |76.28 | 92.60 |[76.22](https://github.com/liu-zhy/temporal-adaptive-module/blob/master/scripts/test_tam_kinetics_rgb_8f.sh)|[92.53](https://github.com/liu-zhy/temporal-adaptive-module/blob/master/scripts/test_tam_kinetics_rgb_8f.sh) | x | 7124 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tanet/tanet_r50_dense_1x1x8_100e_kinetics400_rgb/tanet_r50_dense_1x1x8_100e_kinetics400_rgb_20210219-032c8e94.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tanet/tanet_r50_dense_1x1x8_100e_kinetics400_rgb/tanet_r50_dense_1x1x8_100e_kinetics400_rgb_20210219.log)| [json](https://download.openmmlab.com/mmaction/recognition/tanet/tanet_r50_dense_1x1x8_100e_kinetics400_rgb/tanet_r50_dense_1x1x8_100e_kinetics400_rgb_20210219.json)|

Notes:

1. The **gpus** column indicates the number of GPUs used to produce the checkpoint. The provided configs assume 8 GPUs by default. According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use a different number of GPUs or videos per GPU, e.g., lr=0.01 for 8 GPUs x 8 videos/gpu and lr=0.04 for 16 GPUs x 16 videos/gpu (see the sketch after these notes).
2. The values in the "reference" columns are results obtained by testing the checkpoints provided by the authors on our dataset, with the same model settings. The reference checkpoints can be downloaded [here](https://drive.google.com/drive/folders/1sFfmP3yrfc7IzRshEELOby7-aEoymIFL?usp=sharing).
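
As a quick illustration of the Linear Scaling Rule in note 1, here is a hypothetical helper (not part of the repo) that scales the base learning rate by the ratio of total batch sizes:

```python
def scaled_lr(gpus, videos_per_gpu, base_lr=0.01, base_batch=8 * 8):
    """Hypothetical helper: scale the base lr by the total-batch-size ratio."""
    return base_lr * (gpus * videos_per_gpu) / base_batch


print(scaled_lr(8, 8))    # 0.01, the default 8 GPUs x 8 videos/gpu
print(scaled_lr(16, 16))  # 0.04
```
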
For more details on data preparation, you can refer to Kinetics400 in [Data Preparation](/docs/data_preparation.md).

## Train

You can use the following command to train a model.

```shell
python tools/train.py ${CONFIG_FILE} [optional arguments]
```

Example: train the TANet model on the Kinetics-400 dataset deterministically, with periodic validation.

```shell
python tools/train.py configs/recognition/tanet/tanet_r50_dense_1x1x8_100e_kinetics400_rgb.py \
    --work-dir work_dirs/tanet_r50_dense_1x1x8_100e_kinetics400_rgb \
    --validate --seed 0 --deterministic
```

For more details, you can refer to the **Training setting** part in [getting_started](/docs/getting_started.md#training-setting).

## Test

You can use the following command to test a model.

```shell
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]
```

Example: test the TANet model on the Kinetics-400 dataset and dump the result to a json file.

```shell
python tools/test.py configs/recognition/tanet/tanet_r50_dense_1x1x8_100e_kinetics400_rgb.py \
    checkpoints/SOME_CHECKPOINT.pth --eval top_k_accuracy mean_class_accuracy \
    --out result.json
```

For more details, you can refer to the **Test a dataset** part in [getting_started](/docs/getting_started.md#test-a-dataset).
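
For orientation, the `top_k_accuracy` metric requested above can be sketched in a few lines. This is a simplified stand-in, not mmaction2's implementation:

```python
import numpy as np


def top_k_accuracy(scores, labels, k=5):
    """Fraction of samples whose true label is among the k highest scores."""
    topk = np.argsort(scores, axis=1)[:, -k:]
    return float(np.mean([label in row for label, row in zip(labels, topk)]))


scores = np.array([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
print(top_k_accuracy(scores, np.array([1, 2]), k=1))  # 0.5
```
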
configs/recognition/tanet/tanet_r50_dense_1x1x8_100e_kinetics400_rgb.py (+100)

```python
_base_ = [
    '../../_base_/models/tanet_r50.py', '../../_base_/default_runtime.py'
]

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)

train_pipeline = [
    dict(type='DenseSampleFrames', clip_len=1, frame_interval=1, num_clips=8),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(
        type='MultiScaleCrop',
        input_size=224,
        scales=(1, 0.875, 0.75, 0.66),
        random_crop=False,
        max_wh_scale_gap=1,
        num_fixed_crops=13),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
    dict(
        type='DenseSampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=8,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
    dict(
        type='DenseSampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=8,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
data = dict(
    videos_per_gpu=8,
    workers_per_gpu=4,
    test_dataloader=dict(videos_per_gpu=2),
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline))
evaluation = dict(
    interval=2, metrics=['top_k_accuracy', 'mean_class_accuracy'])

# optimizer
optimizer = dict(
    type='SGD',
    constructor='TSMOptimizerConstructor',
    paramwise_cfg=dict(fc_lr5=True),
    lr=0.01,  # this lr is used for 8 gpus
    momentum=0.9,
    weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[50, 75, 90])
total_epochs = 100

# runtime settings
work_dir = './work_dirs/tanet_r50_dense_1x1x8_100e_kinetics400_rgb/'
```
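
To see what this pipeline feeds the model, here is a rough shape walk-through. The numbers come from the config above; folding segments into the batch dimension is the usual handling of 2D backbones by `Recognizer2D`:

```python
# DenseSampleFrames(clip_len=1, num_clips=8) samples 8 frames per video, and
# FormatShape(input_format='NCHW') folds the segments into the batch
# dimension, so each GPU sees videos_per_gpu * num_clips frames at a time.
videos_per_gpu = 8
num_clips = 8
crop_size = 224  # 256 with ThreeCrop at test time

batch_frames = videos_per_gpu * num_clips
input_shape = (batch_frames, 3, crop_size, crop_size)
print(input_shape)  # (64, 3, 224, 224) per GPU during training
```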

mmaction/core/optimizer/tsm_optimizer_constructor.py (+2 -1)

```diff
@@ -54,7 +54,8 @@ def add_params(self, params, model):
         elif isinstance(m, torch.nn.Linear):
             m_params = list(m.parameters())
             normal_weight.append(m_params[0])
-            normal_bias.append(m_params[1])
+            if len(m_params) == 2:
+                normal_bias.append(m_params[1])
         elif isinstance(m,
                         (_BatchNorm, SyncBatchNorm, torch.nn.GroupNorm)):
             for param in list(m.parameters()):
```
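
A minimal sketch (outside the repo) of why this guard is needed: an `nn.Linear` created with `bias=False` exposes only its weight parameter, so indexing `m_params[1]` unconditionally would raise an `IndexError`, presumably what the "fix tsm_optimizer bug" commit hit once TAM's bias-free linear layers entered the model:

```python
import torch.nn as nn

# nn.Linear with bias=False has a single parameter (the weight matrix).
with_bias = list(nn.Linear(4, 2).parameters())
without_bias = list(nn.Linear(4, 2, bias=False).parameters())
print(len(with_bias), len(without_bias))  # 2 1
```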

mmaction/models/__init__.py (+4 -4)

```diff
@@ -1,11 +1,11 @@
 from .backbones import (C3D, X3D, MobileNetV2, MobileNetV2TSM, ResNet,
                         ResNet2Plus1d, ResNet3d, ResNet3dCSN, ResNet3dLayer,
                         ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio,
-                        ResNetTIN, ResNetTSM)
+                        ResNetTIN, ResNetTSM, TANet)
 from .builder import (DETECTORS, build_backbone, build_detector, build_head,
                       build_localizer, build_loss, build_model, build_neck,
                       build_recognizer)
-from .common import Conv2plus1d, ConvAudio
+from .common import TAM, Conv2plus1d, ConvAudio
 from .heads import (AudioTSNHead, AVARoIHead, BaseHead, BBoxHeadAVA, I3DHead,
                     SlowFastHead, TPNHead, TSMHead, TSNHead, X3DHead)
 from .localizers import BMN, PEM, TEM
@@ -25,10 +25,10 @@
     'BaseRecognizer', 'LOSSES', 'CrossEntropyLoss', 'NLLLoss', 'HVULoss',
     'ResNetTSM', 'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d',
     'ResNet3dSlowOnly', 'BCELossWithLogits', 'LOCALIZERS', 'build_localizer',
-    'PEM', 'TEM', 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss',
+    'PEM', 'TAM', 'TEM', 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss',
     'build_model', 'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN',
     'TPN', 'TPNHead', 'build_loss', 'build_neck', 'AudioRecognizer',
     'AudioTSNHead', 'X3D', 'X3DHead', 'ResNet3dLayer', 'DETECTORS',
     'SingleRoIExtractor3D', 'BBoxHeadAVA', 'ResNetAudio', 'build_detector',
-    'ConvAudio', 'AVARoIHead', 'MobileNetV2', 'MobileNetV2TSM'
+    'ConvAudio', 'AVARoIHead', 'MobileNetV2', 'MobileNetV2TSM', 'TANet'
 ]
```
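
With `TANet` and `TAM` now exported, the backbone can also be built through the registry, mirroring the config file above. A short sketch, assuming this commit is installed:

```python
from mmaction.models import build_backbone

# Equivalent to the backbone dict in configs/_base_/models/tanet_r50.py,
# minus the torchvision pretraining.
backbone = build_backbone(dict(type='TANet', depth=50, num_segments=8))
backbone.init_weights()
```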

mmaction/models/backbones/__init__.py (+2 -1)

```diff
@@ -10,10 +10,11 @@
 from .resnet_audio import ResNetAudio
 from .resnet_tin import ResNetTIN
 from .resnet_tsm import ResNetTSM
+from .tanet import TANet
 from .x3d import X3D

 __all__ = [
     'C3D', 'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d',
     'ResNet3dSlowFast', 'ResNet3dSlowOnly', 'ResNet3dCSN', 'ResNetTIN', 'X3D',
-    'ResNetAudio', 'ResNet3dLayer', 'MobileNetV2TSM', 'MobileNetV2'
+    'ResNetAudio', 'ResNet3dLayer', 'MobileNetV2TSM', 'MobileNetV2', 'TANet'
 ]
```

mmaction/models/backbones/tanet.py (+114)

```python
from copy import deepcopy

import torch.nn as nn
from torch.utils import checkpoint as cp

from ..common import TAM
from ..registry import BACKBONES
from .resnet import Bottleneck, ResNet


class TABlock(nn.Module):
    """Temporal Adaptive Block (TA-Block) for TANet.

    This block is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    The temporal adaptive module (TAM) is embedded into the ResNet block
    after the first Conv2D, which turns a vanilla ResNet block into a
    TA-Block.

    Args:
        block (nn.Module): Residual block to be substituted.
        num_segments (int): Number of frame segments.
        tam_cfg (dict): Config for temporal adaptive module (TAM).
            Default: dict().
    """

    def __init__(self, block, num_segments, tam_cfg=dict()):
        super().__init__()
        self.tam_cfg = deepcopy(tam_cfg)
        self.block = block
        self.num_segments = num_segments
        self.tam = TAM(
            in_channels=block.conv1.out_channels,
            num_segments=num_segments,
            **self.tam_cfg)

        if not isinstance(self.block, Bottleneck):
            raise NotImplementedError('TA-Blocks have not been fully '
                                      'implemented except the pattern based '
                                      'on Bottleneck block.')

    def forward(self, x):
        if isinstance(self.block, Bottleneck):

            def _inner_forward(x):
                """Forward wrapper for utilizing checkpoint."""
                identity = x

                out = self.block.conv1(x)
                out = self.tam(out)
                out = self.block.conv2(out)
                out = self.block.conv3(out)

                if self.block.downsample is not None:
                    identity = self.block.downsample(x)

                out = out + identity

                return out

            if self.block.with_cp and x.requires_grad:
                out = cp.checkpoint(_inner_forward, x)
            else:
                out = _inner_forward(x)

            out = self.block.relu(out)

        return out


@BACKBONES.register_module()
class TANet(ResNet):
    """Temporal Adaptive Network (TANet) backbone.

    This backbone is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    Embedding the temporal adaptive module (TAM) into ResNet
    instantiates TANet.

    Args:
        depth (int): Depth of ResNet, from {18, 34, 50, 101, 152}.
        num_segments (int): Number of frame segments.
        tam_cfg (dict | None): Config for temporal adaptive module (TAM).
            Default: dict().
        **kwargs (keyword arguments, optional): Arguments for ResNet except
            ``depth``.
    """

    def __init__(self, depth, num_segments, tam_cfg=dict(), **kwargs):
        super().__init__(depth, **kwargs)
        assert num_segments >= 3
        self.num_segments = num_segments
        self.tam_cfg = deepcopy(tam_cfg)

    def init_weights(self):
        super().init_weights()
        self.make_tam_modeling()

    def make_tam_modeling(self):
        """Replace each ResNet block with a TA-Block."""

        def make_tam_block(stage, num_segments, tam_cfg=dict()):
            blocks = list(stage.children())
            for i, block in enumerate(blocks):
                blocks[i] = TABlock(block, num_segments, deepcopy(tam_cfg))
            return nn.Sequential(*blocks)

        for i in range(self.num_stages):
            layer_name = f'layer{i + 1}'
            res_layer = getattr(self, layer_name)
            setattr(self, layer_name,
                    make_tam_block(res_layer, self.num_segments, self.tam_cfg))
```
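
A minimal usage sketch for the new backbone (shapes assumed, not taken from the repo's tests). Note that `init_weights()` is what swaps the ResNet blocks for TA-Blocks via `make_tam_modeling()`:

```python
import torch

from mmaction.models import TANet

model = TANet(depth=50, num_segments=8)
model.init_weights()  # initializes weights, then replaces blocks with TA-Blocks

# Recognizer2D presents a video batch to the backbone as
# N * num_segments stacked frames.
imgs = torch.randn(2 * 8, 3, 224, 224)  # 2 videos x 8 segments
feat = model(imgs)
print(feat.shape)  # expected: torch.Size([16, 2048, 7, 7])
```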

mmaction/models/common/__init__.py (+2 -1)

```diff
@@ -1,4 +1,5 @@
 from .conv2plus1d import Conv2plus1d
 from .conv_audio import ConvAudio
+from .tam import TAM

-__all__ = ['Conv2plus1d', 'ConvAudio']
+__all__ = ['Conv2plus1d', 'ConvAudio', 'TAM']
```
