Skip to content

Commit

Permalink
Merge pull request #1 from open-mmlab/master
Browse files Browse the repository at this point in the history
upstream
  • Loading branch information
AllentDan authored May 25, 2021
2 parents 97770c3 + e6406c1 commit cfed165
Show file tree
Hide file tree
Showing 24 changed files with 778 additions and 17 deletions.
1 change: 1 addition & 0 deletions configs/restorers/basicvsr/basicvsr_reds4.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,4 @@
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
1 change: 1 addition & 0 deletions configs/restorers/basicvsr/basicvsr_vimeo90k_bd.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,4 @@
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
1 change: 1 addition & 0 deletions configs/restorers/basicvsr/basicvsr_vimeo90k_bi.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,4 @@
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
1 change: 1 addition & 0 deletions configs/restorers/iconvsr/iconvsr_reds4.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
1 change: 1 addition & 0 deletions configs/restorers/iconvsr/iconvsr_vimeo90k_bd.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,4 @@
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
1 change: 1 addition & 0 deletions configs/restorers/iconvsr/iconvsr_vimeo90k_bi.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,4 @@
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
7 changes: 7 additions & 0 deletions demo/restoration_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os

import mmcv
import torch
Expand All @@ -23,6 +24,12 @@ def parse_args():
def main():
args = parse_args()

if not os.path.isfile(args.img_path):
raise ValueError('It seems that you did not input a valid '
'"image_path". Please double check your input, or '
'you may want to use "restoration_video_demo.py" '
'for video restoration.')

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))

Expand Down
47 changes: 47 additions & 0 deletions demo/restoration_video_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import argparse

import mmcv
import torch

from mmedit.apis import init_model, restoration_video_inference
from mmedit.core import tensor2img


def parse_args():
parser = argparse.ArgumentParser(description='Restoration demo')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('input_dir', help='directory of the input video')
parser.add_argument('output_dir', help='directory of the output video')
parser.add_argument(
'--filename_tmpl',
default='{:08d}.png',
help='template of the file names')
parser.add_argument(
'--window_size',
type=int,
default=0,
help='window size if sliding-window framework is used')
parser.add_argument('--device', type=int, default=0, help='CUDA device id')
args = parser.parse_args()
return args


def main():
args = parse_args()

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))

output = restoration_video_inference(model, args.input_dir,
args.window_size, args.filename_tmpl)
for i in range(0, output.size(1)):
output_i = output[:, i, :, :, :]
output_i = tensor2img(output_i)
save_path_i = f'{args.output_dir}/{i:08d}.png'

mmcv.imwrite(output_i, save_path_i)


if __name__ == '__main__':
main()
24 changes: 22 additions & 2 deletions docs/demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ python demo/matting_demo.py configs/mattors/dim/dim_stage3_v16_pln_1x1_1000k_com

The predicted alpha matte will be save in `tests/data/pred/GT05.png`.

#### Restoration
#### Restoration (Image)

You can use the following commands to test an image for restoration.

Expand All @@ -47,8 +47,28 @@ If `--imshow` is specified, the demo will also show image with opencv. Examples:
```shell
python demo/restoration_demo.py configs/restorer/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k.py work_dirs/esrgan_x4c64b23g32_1x16_400k_div2k/latest.pth tests/data/lq/baboon_x4.png demo/demo_out_baboon.png
```
#### Restoration (Video)

The restored image will be save in `demo/demo_out_baboon.png`.
You can use the following commands to test a video for restoration.

```shell
python demo/restoration_video_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${INPUT_DIR} ${OUTPUT_DIR} [--window_size=$WINDOW_SIZE] [--device ${GPU_ID}]
```

It suppots both the sliding-window framework and the recurrent framework. Examples:


EDVR:
```shell
python demo/restoration_video_demo.py ./configs/restorers/edvr/edvrm_wotsa_x4_g8_600k_reds.py https://download.openmmlab.com/mmediting/restorers/edvr/edvrm_wotsa_x4_8x4_600k_reds_20200522-0570e567.pth data/Vid4/BIx4/calendar/ ./output --window_size=5
```

BasicVSR:
```shell
python demo/restoration_video_demo.py ./configs/restorers/basicvsr/basicvsr_reds4.py https://download.openmmlab.com/mmediting/restorers/basicvsr/basicvsr_reds4_20120409-0e599677.pth data/Vid4/BIx4/calendar/ ./output
```

The restored video will be save in `output/`.

#### Generation

Expand Down
3 changes: 2 additions & 1 deletion mmedit/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from .inpainting_inference import inpainting_inference
from .matting_inference import init_model, matting_inference
from .restoration_inference import restoration_inference
from .restoration_video_inference import restoration_video_inference
from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model

__all__ = [
'train_model', 'set_random_seed', 'init_model', 'matting_inference',
'inpainting_inference', 'restoration_inference', 'generation_inference',
'multi_gpu_test', 'single_gpu_test'
'multi_gpu_test', 'single_gpu_test', 'restoration_video_inference'
]
80 changes: 80 additions & 0 deletions mmedit/apis/restoration_video_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import glob

import torch
from mmcv.parallel import collate, scatter

from mmedit.datasets.pipelines import Compose


def pad_sequence(data, window_size):
padding = window_size // 2

data = torch.cat([
data[:, 1 + padding:1 + 2 * padding].flip(1), data,
data[:, -1 - 2 * padding:-1 - padding].flip(1)
],
dim=1)

return data


def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
"""Inference image with the model.
Args:
model (nn.Module): The loaded model.
img_dir (str): Directory of the input video.
window_size (int): The window size used in sliding-window framework.
This value should be set according to the settings of the network.
A value smaller than 0 means using recurrent framework.
Returns:
Tensor: The predicted restoration result.
"""

device = next(model.parameters()).device # model device

# pipeline
test_pipeline = [
dict(
type='GenerateSegmentIndices',
interval_list=[1],
filename_tmpl=filename_tmpl),
dict(
type='LoadImageFromFileList',
io_backend='disk',
key='lq',
channel_order='rgb'),
dict(type='RescaleToZeroOne', keys=['lq']),
dict(type='FramesToTensor', keys=['lq']),
dict(type='Collect', keys=['lq'], meta_keys=['lq_path', 'key'])
]

# build the data pipeline
test_pipeline = Compose(test_pipeline)

# prepare data
sequence_length = len(glob.glob(f'{img_dir}/*'))
key = img_dir.split('/')[-1]
lq_folder = '/'.join(img_dir.split('/')[:-1])
data = dict(
lq_path=lq_folder,
gt_path='',
key=key,
sequence_length=sequence_length)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

# forward the model
with torch.no_grad():
if window_size > 0: # sliding window framework
data = pad_sequence(data, window_size)
result = []
for i in range(0, data.size(1) - 2 * window_size):
data_i = data[:, i:i + window_size]
result.append(model(lq=data_i, test_mode=True)['output'])
result = torch.stack(result, dim=1)
else: # recurrent framework
result = model(lq=data, test_mode=True)['output']

return result
3 changes: 2 additions & 1 deletion mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .sr_annotation_dataset import SRAnnotationDataset
from .sr_folder_dataset import SRFolderDataset
from .sr_folder_gt_dataset import SRFolderGTDataset
from .sr_folder_ref_dataset import SRFolderRefDataset
from .sr_lmdb_dataset import SRLmdbDataset
from .sr_reds_dataset import SRREDSDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
Expand All @@ -28,5 +29,5 @@
'SRVimeo90KDataset', 'BaseGenerationDataset', 'GenerationPairedDataset',
'GenerationUnpairedDataset', 'SRVid4Dataset', 'SRFolderGTDataset',
'SRREDSMultipleGTDataset', 'SRVimeo90KMultipleGTDataset',
'SRTestMultipleGTDataset'
'SRTestMultipleGTDataset', 'SRFolderRefDataset'
]
8 changes: 5 additions & 3 deletions mmedit/datasets/pipelines/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,10 +927,12 @@ class GenerateSegmentIndices:
interval_list (list[int]): Interval list for temporal augmentation.
It will randomly pick an interval from interval_list and sample
frame index with the interval.
filename_tmpl (str): Template for file name. Default: '{:08d}.png'.
"""

def __init__(self, interval_list):
def __init__(self, interval_list, filename_tmpl='{:08d}.png'):
self.interval_list = interval_list
self.filename_tmpl = filename_tmpl

def __call__(self, results):
"""Call function.
Expand Down Expand Up @@ -964,11 +966,11 @@ def __call__(self, results):
lq_path_root = results['lq_path']
gt_path_root = results['gt_path']
lq_path = [
osp.join(lq_path_root, clip_name, f'{v:08d}.png')
osp.join(lq_path_root, clip_name, self.filename_tmpl.format(v))
for v in neighbor_list
]
gt_path = [
osp.join(gt_path_root, clip_name, f'{v:08d}.png')
osp.join(gt_path_root, clip_name, self.filename_tmpl.format(v))
for v in neighbor_list
]

Expand Down
124 changes: 124 additions & 0 deletions mmedit/datasets/sr_folder_ref_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os.path as osp

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRFolderRefDataset(BaseSRDataset):
"""
General paired image folder dataset for reference-based image restoration.
The dataset loads ref (reference) image pairs
Must contain: ref (reference)
Optional: GT (Ground-Truth), LQ (Low Quality), or both
Cannot only contain ref.
Applies specified transforms and finally returns a dict containing paired
data and other information.
This is the "folder mode", which needs to specify the ref folder path and
gt folder path, each folder containing the corresponding images.
Image lists will be generated automatically. You can also specify the
filename template to match the image pairs.
For example, we have three folders with the following structures:
::
data_root
├── ref
│ ├── 0001.png
│ ├── 0002.png
├── gt
│ ├── 0001.png
│ ├── 0002.png
├── lq
│ ├── 0001_x4.png
│ ├── 0002_x4.png
then, you need to set:
.. code-block:: python
ref_folder = 'data_root/ref'
gt_folder = 'data_root/gt'
lq_folder = 'data_root/lq'
filename_tmpl_gt='{}'
filename_tmpl_lq='{}_x4'
Args:
pipeline (List[dict | callable]): A sequence of data transformations.
scale (int): Upsampling scale ratio.
ref_folder (str | :obj:`Path`): Path to a ref folder.
gt_folder (str | :obj:`Path` | None): Path to a gt folder.
Default: None.
lq_folder (str | :obj:`Path` | None): Path to a gt folder.
Default: None.
test_mode (bool): Store `True` when building test dataset.
Default: `False`.
filename_tmpl_gt (str): Template for gt filename. Note that the
template excludes the file extension. Default: '{}'.
filename_tmpl_lq (str): Template for lq filename. Note that the
template excludes the file extension. Default: '{}'.
"""

def __init__(self,
pipeline,
scale,
ref_folder,
gt_folder=None,
lq_folder=None,
test_mode=False,
filename_tmpl_gt='{}',
filename_tmpl_lq='{}'):
super().__init__(pipeline, scale, test_mode)
assert gt_folder or lq_folder, 'At least one of gt_folder and' \
'lq_folder cannot be None.'
self.scale = scale
self.ref_folder = str(ref_folder)
self.gt_folder = str(gt_folder) if gt_folder else None
self.lq_folder = str(lq_folder) if lq_folder else None
self.filename_tmpl_gt = filename_tmpl_gt
self.filename_tmpl_lq = filename_tmpl_lq
self.data_infos = self.load_annotations()

def load_annotations(self):
"""Load annoations for SR dataset.
It loads the ref, LQ and GT image path from folders.
Returns:
dict: Returned dict for ref, LQ and GT pairs.
"""
data_infos = []
ref_paths = self.scan_folder(self.ref_folder)
if self.gt_folder is not None:
gt_paths = self.scan_folder(self.gt_folder)
assert len(ref_paths) == len(gt_paths), (
f'ref and gt datasets have different number of images: '
f'{len(ref_paths)}, {len(gt_paths)}.')
if self.lq_folder is not None:
lq_paths = self.scan_folder(self.lq_folder)
assert len(ref_paths) == len(lq_paths), (
f'ref and lq datasets have different number of images: '
f'{len(ref_paths)}, {len(lq_paths)}.')
for ref_path in ref_paths:
basename, ext = osp.splitext(osp.basename(ref_path))
data_dict = dict(ref_path=ref_path)
if self.gt_folder is not None:
gt_path = osp.join(self.gt_folder,
(f'{self.filename_tmpl_gt.format(basename)}'
f'{ext}'))
assert gt_path in gt_paths, \
f'{gt_path} is not in gt_paths.'
data_dict['gt_path'] = gt_path
if self.lq_folder is not None:
lq_path = osp.join(self.lq_folder,
(f'{self.filename_tmpl_lq.format(basename)}'
f'{ext}'))
assert lq_path in lq_paths, \
f'{lq_path} is not in lq_paths.'
data_dict['lq_path'] = lq_path
data_infos.append(data_dict)
return data_infos
Loading

0 comments on commit cfed165

Please sign in to comment.