Skip to content

Commit

Permalink
[Feature] Add TDAN config and models (#347)
Browse files Browse the repository at this point in the history
* Add TDAN config and models

* Add training and test descriptions

* Update readme descrptions

* Update README
  • Loading branch information
ckkelvinchan authored Jun 1, 2021
1 parent 290c75e commit 085a277
Show file tree
Hide file tree
Showing 7 changed files with 577 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Supported algorithms:
- [x] [RDN](configs/restorers/rdn/README.md) (CVPR'2018)
- [x] [SRCNN](configs/restorers/srcnn/README.md) (TPAMI'2015)
- [x] [SRResNet&SRGAN](configs/restorers/srresnet_srgan/README.md) (CVPR'2016)
- [x] [TDAN](configs/restorers/tdan/README.md) (CVPR'2020)
- [x] [TOF](configs/restorers/tof/README.md) (IJCV'2019)
- [x] [TTSR](configs/restorers/ttsr/README.md) (CVPR'2020)

Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ MMEditing 是基于 PyTorch 的图像&视频编辑开源工具箱。是 [OpenMML
- [x] [RDN](configs/restorers/rdn/README.md) (CVPR'2018)
- [x] [SRCNN](configs/restorers/srcnn/README.md) (TPAMI'2015)
- [x] [SRResNet&SRGAN](configs/restorers/srresnet_srgan/README.md) (CVPR'2016)
- [x] [TDAN](configs/restorers/tdan/README.md) (CVPR'2020)
- [x] [TOF](configs/restorers/tof/README.md) (IJCV'2019)
- [x] [TTSR](configs/restorers/ttsr/README.md) (CVPR'2020)

Expand Down
67 changes: 67 additions & 0 deletions configs/restorers/tdan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# TDAN: Temporally-Deformable Alignment Network for Video Super-Resolution

## Introduction

<!-- [ALGORITHM] -->

```bibtex
@InProceedings{tian2020tdan,
title={TDAN: Temporally-Deformable Alignment Network for Video Super-Resolution},
author={Tian, Yapeng and Zhang, Yulun and Fu, Yun and Xu, Chenliang},
booktitle = {Proceedings of the IEEE conference on Computer Vision and Pattern Recognition},
year = {2020}
}
```

## Results and Models

Evaluated on Y-channel. 8 pixels in each border are cropped before evaluation.

The metrics are `PSNR / SSIM`.

| Method | Vid4 (BIx4) | SPMCS-30 (BIx4) | Vid4 (BDx4) | SPMCS-30 (BDx4) | Download |
|:-------------------------------------------------------------------:|:---------------:|:---------------:|:---------------:|:---------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| [tdan_vimeo90k_bix4](/configs/restorers/tdan/tdan_vimeo90k_bix4.py) | **26.49/0.792** | **30.42/0.856** | 25.93/0.772 | 29.69/0.842 | [model](https://download.openmmlab.com/mmediting/restorers/tdan/tdan_vimeo90k_bix4_20210528-739979d9.pth) \| [log](https://download.openmmlab.com/mmediting/restorers/tdan/tdan_vimeo90k_bix4_20210528_135616.log.json) |
| [tdan_vimeo90k_bdx4](/configs/restorers/tdan/tdan_vimeo90k_bdx4.py) | 25.80/0.784 | 29.56/0.851 | **26.87/0.815** | **30.77/0.868** | [model](https://download.openmmlab.com/mmediting/restorers/tdan/tdan_vimeo90k_bdx4_20210528-c53ab844.pth) \| [log](https://download.openmmlab.com/mmediting/restorers/tdan/tdan_vimeo90k_bdx4_20210528_122401.log.json) |


## Train

You can use the following command to train a model.

```shell
./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
```

TDAN is trained with two stages.

**Stage 1**: Train with a larger learning rate (1e-4)


```shell
./tools/dist_train.sh configs/restorers/tdan/tdan_vimeo90k_bix4_lr1e-4_400k.py 8
```

**Stage 2**: Fine-tune with a smaller learning rate (5e-5)

```shell
./tools/dist_train.sh configs/restorers/tdan/tdan_vimeo90k_bix4_ft_lr5e-5_400k.py 8
```

For more details, you can refer to **Train a model** part in [getting_started](/docs/getting_started.md#train-a-model).

## Test

You can use the following command to test a model.

```shell
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--save-path ${IMAGE_SAVE_PATH}]
```

Example: Test TDAN on SPMCS-30 using Bicubic downsampling.

```shell
python tools/test.py configs/restorers/tdan/tdan_vimeo90k_bix4_ft_lr5e-5_400k.py checkpoints/SOME_CHECKPOINT.pth --save_path outputs/
```

For more details, you can refer to **Inference with pretrained models** part in [getting_started](/docs/getting_started.md#inference-with-pretrained-models).
127 changes: 127 additions & 0 deletions configs/restorers/tdan/tdan_vimeo90k_bdx4_ft_lr5e-5_800k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
exp_name = 'tdan_vimeo90k_bdx4_ft_lr5e-5_800k'

# model settings
model = dict(
type='TDAN',
generator=dict(type='TDANNet'),
pixel_loss=dict(type='MSELoss', loss_weight=1.0, reduction='mean'),
lq_pixel_loss=dict(type='MSELoss', loss_weight=0.01, reduction='mean'))
# model training and testing settings
train_cfg = None
test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=8, convert_to='y')

# dataset settings
train_dataset_type = 'SRVimeo90KDataset'
val_dataset_type = 'SRVid4Dataset'

train_pipeline = [
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='lq',
channel_order='rgb'),
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='gt',
channel_order='rgb'),
dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
dict(
type='Normalize',
keys=['lq', 'gt'],
mean=[0.5, 0.5, 0.5],
std=[1, 1, 1]),
dict(type='PairedRandomCrop', gt_patch_size=192),
dict(
type='Flip', keys=['lq', 'gt'], flip_ratio=0.5,
direction='horizontal'),
dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, direction='vertical'),
dict(type='RandomTransposeHW', keys=['lq', 'gt'], transpose_ratio=0.5),
dict(type='FramesToTensor', keys=['lq', 'gt']),
dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path'])
]

val_pipeline = [
dict(type='GenerateFrameIndiceswithPadding', padding='reflection'),
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='lq',
channel_order='rgb'),
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='gt',
channel_order='rgb'),
dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
dict(
type='Normalize',
keys=['lq', 'gt'],
mean=[0.5, 0.5, 0.5],
std=[1, 1, 1]),
dict(type='FramesToTensor', keys=['lq', 'gt']),
dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path'])
]

data = dict(
workers_per_gpu=8,
train_dataloader=dict(samples_per_gpu=16, drop_last=True), # 8 gpus
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='RepeatDataset',
times=1000,
dataset=dict(
type=train_dataset_type,
lq_folder='data/Vimeo-90K/BDx4',
gt_folder='data/Vimeo-90K/GT',
ann_file='data/Vimeo-90K/meta_info_Vimeo90K_train_GT.txt',
num_input_frames=5,
pipeline=train_pipeline,
scale=4,
test_mode=False)),
val=dict(
type=val_dataset_type,
lq_folder='data/Vid4/BDx4',
gt_folder='data/Vid4/GT',
pipeline=val_pipeline,
ann_file='data/Vid4/meta_info_Vid4_GT.txt',
scale=4,
num_input_frames=5,
test_mode=True),
test=dict(
type=val_dataset_type,
lq_folder='data/SPMCS/BDx4',
gt_folder='data/SPMCS/GT',
pipeline=val_pipeline,
ann_file='data/SPMCS/meta_info_SPMCS_GT.txt',
scale=4,
num_input_frames=5,
test_mode=True),
)

# optimizer
optimizers = dict(generator=dict(type='Adam', lr=5e-5))

# learning policy
total_iters = 800000
lr_config = dict(policy='Step', by_epoch=False, step=[800000], gamma=0.5)

checkpoint_config = dict(interval=50000, save_optimizer=True, by_epoch=False)
# remove gpu_collect=True in non distributed training
evaluation = dict(interval=50000, save_image=False, gpu_collect=True)
log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
# dict(type='TensorboardLoggerHook'),
])
visual_config = None

# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = f'./work_dirs/{exp_name}'
load_from = './experiments/tdan_vimeo90k_bdx4_lr1e-4_400k/iter_400000.pth'
resume_from = None
workflow = [('train', 1)]
127 changes: 127 additions & 0 deletions configs/restorers/tdan/tdan_vimeo90k_bdx4_lr1e-4_400k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
exp_name = 'tdan_vimeo90k_bdx4_lr1e-4_400k'

# model settings
model = dict(
type='TDAN',
generator=dict(type='TDANNet'),
pixel_loss=dict(type='MSELoss', loss_weight=1.0, reduction='mean'),
lq_pixel_loss=dict(type='MSELoss', loss_weight=0.01, reduction='mean'))
# model training and testing settings
train_cfg = None
test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=8, convert_to='y')

# dataset settings
train_dataset_type = 'SRVimeo90KDataset'
val_dataset_type = 'SRVid4Dataset'

train_pipeline = [
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='lq',
channel_order='rgb'),
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='gt',
channel_order='rgb'),
dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
dict(
type='Normalize',
keys=['lq', 'gt'],
mean=[0.5, 0.5, 0.5],
std=[1, 1, 1]),
dict(type='PairedRandomCrop', gt_patch_size=192),
dict(
type='Flip', keys=['lq', 'gt'], flip_ratio=0.5,
direction='horizontal'),
dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, direction='vertical'),
dict(type='RandomTransposeHW', keys=['lq', 'gt'], transpose_ratio=0.5),
dict(type='FramesToTensor', keys=['lq', 'gt']),
dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path'])
]

val_pipeline = [
dict(type='GenerateFrameIndiceswithPadding', padding='reflection'),
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='lq',
channel_order='rgb'),
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='gt',
channel_order='rgb'),
dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
dict(
type='Normalize',
keys=['lq', 'gt'],
mean=[0.5, 0.5, 0.5],
std=[1, 1, 1]),
dict(type='FramesToTensor', keys=['lq', 'gt']),
dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path'])
]

data = dict(
workers_per_gpu=8,
train_dataloader=dict(samples_per_gpu=16, drop_last=True), # 8 gpus
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='RepeatDataset',
times=1000,
dataset=dict(
type=train_dataset_type,
lq_folder='data/Vimeo-90K/BDx4',
gt_folder='data/Vimeo-90K/GT',
ann_file='data/Vimeo-90K/meta_info_Vimeo90K_train_GT.txt',
num_input_frames=5,
pipeline=train_pipeline,
scale=4,
test_mode=False)),
val=dict(
type=val_dataset_type,
lq_folder='data/Vid4/BDx4',
gt_folder='data/Vid4/GT',
pipeline=val_pipeline,
ann_file='data/Vid4/meta_info_Vid4_GT.txt',
scale=4,
num_input_frames=5,
test_mode=True),
test=dict(
type=val_dataset_type,
lq_folder='data/SPMCS/BDx4',
gt_folder='data/SPMCS/GT',
pipeline=val_pipeline,
ann_file='data/SPMCS/meta_info_SPMCS_GT.txt',
scale=4,
num_input_frames=5,
test_mode=True),
)

# optimizer
optimizers = dict(generator=dict(type='Adam', lr=1e-4, weight_decay=1e-6))

# learning policy
total_iters = 800000
lr_config = dict(policy='Step', by_epoch=False, step=[800000], gamma=0.5)

checkpoint_config = dict(interval=50000, save_optimizer=True, by_epoch=False)
# remove gpu_collect=True in non distributed training
evaluation = dict(interval=50000, save_image=False, gpu_collect=True)
log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
# dict(type='TensorboardLoggerHook'),
])
visual_config = None

# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = f'./work_dirs/{exp_name}'
load_from = None
resume_from = None
workflow = [('train', 1)]
Loading

0 comments on commit 085a277

Please sign in to comment.