-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from open-mmlab/master
upstream
- Loading branch information
Showing
24 changed files
with
778 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,3 +138,4 @@ | |
# Runtime settings at the tail of this config. Judging by the diff hunk,
# only ``find_unused_parameters`` is added by this commit; the other three
# lines are unchanged context.
load_from = None
resume_from = None
workflow = [('train', 1)]
# NOTE(review): presumably forwarded to the distributed data-parallel
# wrapper by the training entry point - confirm where this key is read.
find_unused_parameters = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,3 +155,4 @@ | |
# Runtime settings at the tail of this config. Judging by the diff hunk,
# only ``find_unused_parameters`` is added by this commit; the other three
# lines are unchanged context.
load_from = None
resume_from = None
workflow = [('train', 1)]
# NOTE(review): presumably forwarded to the distributed data-parallel
# wrapper by the training entry point - confirm where this key is read.
find_unused_parameters = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,3 +155,4 @@ | |
# Runtime settings at the tail of this config. Judging by the diff hunk,
# only ``find_unused_parameters`` is added by this commit; the other three
# lines are unchanged context.
load_from = None
resume_from = None
workflow = [('train', 1)]
# NOTE(review): presumably forwarded to the distributed data-parallel
# wrapper by the training entry point - confirm where this key is read.
find_unused_parameters = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,3 +140,4 @@ | |
# Runtime settings at the tail of this config. Judging by the diff hunk,
# only ``find_unused_parameters`` is added by this commit; the other three
# lines are unchanged context.
load_from = None
resume_from = None
workflow = [('train', 1)]
# NOTE(review): presumably forwarded to the distributed data-parallel
# wrapper by the training entry point - confirm where this key is read.
find_unused_parameters = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,3 +159,4 @@ | |
# Runtime settings at the tail of this config. Judging by the diff hunk,
# only ``find_unused_parameters`` is added by this commit; the other three
# lines are unchanged context.
load_from = None
resume_from = None
workflow = [('train', 1)]
# NOTE(review): presumably forwarded to the distributed data-parallel
# wrapper by the training entry point - confirm where this key is read.
find_unused_parameters = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,3 +159,4 @@ | |
# Runtime settings at the tail of this config. Judging by the diff hunk,
# only ``find_unused_parameters`` is added by this commit; the other three
# lines are unchanged context.
load_from = None
resume_from = None
workflow = [('train', 1)]
# NOTE(review): presumably forwarded to the distributed data-parallel
# wrapper by the training entry point - confirm where this key is read.
find_unused_parameters = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import argparse | ||
|
||
import mmcv | ||
import torch | ||
|
||
from mmedit.apis import init_model, restoration_video_inference | ||
from mmedit.core import tensor2img | ||
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line arguments of the video restoration demo.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            command-line entry point relies on.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
            ``checkpoint``, ``input_dir``, ``output_dir``,
            ``filename_tmpl``, ``window_size`` and ``device``.
    """
    parser = argparse.ArgumentParser(description='Restoration demo')

    # positional arguments
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('input_dir', help='directory of the input video')
    parser.add_argument('output_dir', help='directory of the output video')

    # optional arguments
    parser.add_argument(
        '--filename_tmpl',
        default='{:08d}.png',
        help='template of the file names')
    parser.add_argument(
        '--window_size',
        type=int,
        default=0,
        help='window size if sliding-window framework is used')
    parser.add_argument('--device', type=int, default=0, help='CUDA device id')

    return parser.parse_args(argv)
|
||
|
||
def main():
    """Restore a video given as a frame directory and save every output frame.

    Builds the model from the parsed config/checkpoint on the requested CUDA
    device, runs ``restoration_video_inference`` on ``args.input_dir`` and
    writes one image per index of dimension 1 of the returned tensor to
    ``args.output_dir`` as ``{i:08d}.png``.
    """
    args = parse_args()

    model = init_model(
        args.config, args.checkpoint, device=torch.device('cuda', args.device))

    output = restoration_video_inference(model, args.input_dir,
                                         args.window_size, args.filename_tmpl)

    # NOTE(review): dimension 1 is presumably the temporal axis of a
    # (n, t, c, h, w) tensor - confirm against the model output.
    for i in range(output.size(1)):
        frame = tensor2img(output[:, i, :, :, :])
        save_path = f'{args.output_dir}/{i:08d}.png'
        mmcv.imwrite(frame, save_path)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import glob | ||
|
||
import torch | ||
from mmcv.parallel import collate, scatter | ||
|
||
from mmedit.datasets.pipelines import Compose | ||
|
||
|
||
def pad_sequence(data, window_size):
    """Pad a sequence along dimension 1 with flipped slices of itself.

    With ``padding = window_size // 2``, ``padding`` entries are prepended
    and appended. The prepended entries are indices
    ``[1 + padding, 1 + 2 * padding)`` in reversed order; the appended
    entries are indices ``[-1 - 2 * padding, -1 - padding)`` in reversed
    order. For example, a 7-entry sequence with ``window_size=5`` becomes
    ``[4, 3, 0, 1, 2, 3, 4, 5, 6, 3, 2]``.

    Args:
        data (Tensor): Input tensor; dimension 1 is the one being padded.
        window_size (int): Window size from which the padding is derived.

    Returns:
        Tensor: The padded tensor, longer by ``2 * (window_size // 2)``
            along dimension 1.

    Raises:
        ValueError: If padding is requested but dimension 1 has fewer than
            ``2 * (window_size // 2) + 1`` entries. The slices would then be
            silently truncated and the result shorter than expected.
    """
    padding = window_size // 2
    num_frames = data.size(1)

    if padding > 0 and num_frames < 2 * padding + 1:
        raise ValueError(
            f'The sequence has {num_frames} frames, but at least '
            f'{2 * padding + 1} are required for window_size={window_size}.')

    front = data[:, 1 + padding:1 + 2 * padding].flip(1)
    back = data[:, -1 - 2 * padding:-1 - padding].flip(1)

    return torch.cat([front, data, back], dim=1)
|
||
|
||
def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
    """Inference a video (given as a directory of frames) with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value smaller than or equal to 0 means using recurrent
            framework.
        filename_tmpl (str): Template of the frame file names, passed to the
            ``GenerateSegmentIndices`` pipeline step.

    Returns:
        Tensor: The predicted restoration result. In the sliding-window
            framework the per-window outputs are stacked along dimension 1.
    """

    device = next(model.parameters()).device  # model device

    # pipeline
    test_pipeline = [
        dict(
            type='GenerateSegmentIndices',
            interval_list=[1],
            filename_tmpl=filename_tmpl),
        dict(
            type='LoadImageFromFileList',
            io_backend='disk',
            key='lq',
            channel_order='rgb'),
        dict(type='RescaleToZeroOne', keys=['lq']),
        dict(type='FramesToTensor', keys=['lq']),
        dict(type='Collect', keys=['lq'], meta_keys=['lq_path', 'key'])
    ]

    # build the data pipeline
    test_pipeline = Compose(test_pipeline)

    # prepare data: the last path component is the clip key, the rest is the
    # folder containing it; every entry in img_dir counts as one frame
    sequence_length = len(glob.glob(f'{img_dir}/*'))
    path_parts = img_dir.split('/')
    key = path_parts[-1]
    lq_folder = '/'.join(path_parts[:-1])
    data = dict(
        lq_path=lq_folder,
        gt_path='',
        key=key,
        sequence_length=sequence_length)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            # pad_sequence adds window_size // 2 entries on each side, so
            # removing them again gives one window per original frame.
            # Subtracting 2 * window_size here would drop trailing frames.
            num_windows = data.size(1) - 2 * (window_size // 2)
            result = []
            for i in range(num_windows):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'])
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            result = model(lq=data, test_mode=True)['output']

    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import os.path as osp | ||
|
||
from .base_sr_dataset import BaseSRDataset | ||
from .registry import DATASETS | ||
|
||
|
||
@DATASETS.register_module()
class SRFolderRefDataset(BaseSRDataset):
    """General paired image folder dataset for reference-based restoration.

    The dataset loads ref (reference) image pairs
        Must contain: ref (reference)
        Optional: GT (Ground-Truth), LQ (Low Quality), or both
            Cannot only contain ref.

    Applies specified transforms and finally returns a dict containing paired
    data and other information.

    This is the "folder mode", which needs to specify the ref folder path and
    gt folder path, each folder containing the corresponding images.
    Image lists will be generated automatically. You can also specify the
    filename template to match the image pairs.

    For example, we have three folders with the following structures:

    ::

        data_root
        ├── ref
        │   ├── 0001.png
        │   ├── 0002.png
        ├── gt
        │   ├── 0001.png
        │   ├── 0002.png
        ├── lq
        │   ├── 0001_x4.png
        │   ├── 0002_x4.png

    then, you need to set:

    .. code-block:: python

        ref_folder = 'data_root/ref'
        gt_folder = 'data_root/gt'
        lq_folder = 'data_root/lq'
        filename_tmpl_gt='{}'
        filename_tmpl_lq='{}_x4'

    Args:
        pipeline (List[dict | callable]): A sequence of data transformations.
        scale (int): Upsampling scale ratio.
        ref_folder (str | :obj:`Path`): Path to a ref folder.
        gt_folder (str | :obj:`Path` | None): Path to a gt folder.
            Default: None.
        lq_folder (str | :obj:`Path` | None): Path to a lq folder.
            Default: None.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
        filename_tmpl_gt (str): Template for gt filename. Note that the
            template excludes the file extension. Default: '{}'.
        filename_tmpl_lq (str): Template for lq filename. Note that the
            template excludes the file extension. Default: '{}'.
    """

    def __init__(self,
                 pipeline,
                 scale,
                 ref_folder,
                 gt_folder=None,
                 lq_folder=None,
                 test_mode=False,
                 filename_tmpl_gt='{}',
                 filename_tmpl_lq='{}'):
        super().__init__(pipeline, scale, test_mode)
        # The trailing space keeps the two literals from fusing into
        # "andlq_folder".
        assert gt_folder or lq_folder, 'At least one of gt_folder and ' \
            'lq_folder cannot be None.'
        self.scale = scale
        self.ref_folder = str(ref_folder)
        self.gt_folder = str(gt_folder) if gt_folder else None
        self.lq_folder = str(lq_folder) if lq_folder else None
        self.filename_tmpl_gt = filename_tmpl_gt
        self.filename_tmpl_lq = filename_tmpl_lq
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for SR dataset.

        It loads the ref, LQ and GT image path from folders.

        Returns:
            list[dict]: One dict per ref image. Each dict holds ``ref_path``
                plus ``gt_path`` and/or ``lq_path`` when the corresponding
                folder is configured.
        """
        data_infos = []
        ref_paths = self.scan_folder(self.ref_folder)

        # Lengths are compared on the scanned lists; the sets are only used
        # for constant-time membership checks inside the loop below.
        gt_path_set = None
        if self.gt_folder is not None:
            gt_paths = self.scan_folder(self.gt_folder)
            assert len(ref_paths) == len(gt_paths), (
                f'ref and gt datasets have different number of images: '
                f'{len(ref_paths)}, {len(gt_paths)}.')
            gt_path_set = set(gt_paths)

        lq_path_set = None
        if self.lq_folder is not None:
            lq_paths = self.scan_folder(self.lq_folder)
            assert len(ref_paths) == len(lq_paths), (
                f'ref and lq datasets have different number of images: '
                f'{len(ref_paths)}, {len(lq_paths)}.')
            lq_path_set = set(lq_paths)

        for ref_path in ref_paths:
            # The gt/lq file is expected to share the ref file's extension.
            basename, ext = osp.splitext(osp.basename(ref_path))
            data_dict = dict(ref_path=ref_path)
            if gt_path_set is not None:
                gt_path = osp.join(self.gt_folder,
                                   (f'{self.filename_tmpl_gt.format(basename)}'
                                    f'{ext}'))
                assert gt_path in gt_path_set, \
                    f'{gt_path} is not in gt_paths.'
                data_dict['gt_path'] = gt_path
            if lq_path_set is not None:
                lq_path = osp.join(self.lq_folder,
                                   (f'{self.filename_tmpl_lq.format(basename)}'
                                    f'{ext}'))
                assert lq_path in lq_path_set, \
                    f'{lq_path} is not in lq_paths.'
                data_dict['lq_path'] = lq_path
            data_infos.append(data_dict)
        return data_infos
Oops, something went wrong.