Skip to content

Commit

Permalink
Merge pull request #238 from ckkelvinchan/multiple-gt-dataset
Browse files Browse the repository at this point in the history
Add MultipleGT datasets
  • Loading branch information
nbei authored Apr 11, 2021
2 parents 3c26d34 + 91480a2 commit 9243ffc
Show file tree
Hide file tree
Showing 15 changed files with 344 additions and 2 deletions.
7 changes: 6 additions & 1 deletion mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
from .sr_folder_gt_dataset import SRFolderGTDataset
from .sr_lmdb_dataset import SRLmdbDataset
from .sr_reds_dataset import SRREDSDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
from .sr_test_multiple_gt_dataset import SRTestMultipleGTDataset
from .sr_vid4_dataset import SRVid4Dataset
from .sr_vimeo90k_dataset import SRVimeo90KDataset
from .sr_vimeo90k_multiple_gt_dataset import SRVimeo90KMultipleGTDataset

__all__ = [
'DATASETS', 'PIPELINES', 'build_dataset', 'build_dataloader',
'BaseDataset', 'BaseMattingDataset', 'ImgInpaintingDataset',
'AdobeComp1kDataset', 'SRLmdbDataset', 'SRFolderDataset',
'SRAnnotationDataset', 'BaseSRDataset', 'RepeatDataset', 'SRREDSDataset',
'SRVimeo90KDataset', 'BaseGenerationDataset', 'GenerationPairedDataset',
'GenerationUnpairedDataset', 'SRVid4Dataset', 'SRFolderGTDataset'
'GenerationUnpairedDataset', 'SRVid4Dataset', 'SRFolderGTDataset',
'SRREDSMultipleGTDataset', 'SRVimeo90KMultipleGTDataset',
'SRTestMultipleGTDataset'
]
73 changes: 73 additions & 0 deletions mmedit/datasets/sr_reds_multiple_gt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRREDSMultipleGTDataset(BaseSRDataset):
"""REDS dataset for video super resolution for recurrent networks.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
frames. Then it applies specified transforms and finally returns a dict
containing paired data and other information.
Args:
lq_folder (str | :obj:`Path`): Path to a lq folder.
gt_folder (str | :obj:`Path`): Path to a gt folder.
num_input_frames (int): Number of input frames.
pipeline (list[dict | callable]): A sequence of data transformations.
scale (int): Upsampling scale ratio.
val_partition (str): Validation partition mode. Choices ['official' or
'REDS4']. Default: 'official'.
test_mode (bool): Store `True` when building test dataset.
Default: `False`.
"""

def __init__(self,
lq_folder,
gt_folder,
num_input_frames,
pipeline,
scale,
val_partition='official',
test_mode=False):
super().__init__(pipeline, scale, test_mode)
self.lq_folder = str(lq_folder)
self.gt_folder = str(gt_folder)
self.num_input_frames = num_input_frames
self.val_partition = val_partition
self.data_infos = self.load_annotations()

def load_annotations(self):
"""Load annoations for REDS dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
# generate keys
keys = [f'{i:03d}' for i in range(0, 270)]

if self.val_partition == 'REDS4':
val_partition = ['000', '011', '015', '020']
elif self.val_partition == 'official':
val_partition = [f'{i:03d}' for i in range(240, 270)]
else:
raise ValueError(
f'Wrong validation partition {self.val_partition}.'
f'Supported ones are ["official", "REDS4"]')

if self.test_mode:
keys = [v for v in keys if v in val_partition]
else:
keys = [v for v in keys if v not in val_partition]

data_infos = []
for key in keys:
data_infos.append(
dict(
lq_path=self.lq_folder,
gt_path=self.gt_folder,
key=key,
sequence_length=100, # REDS has 100 frames for each clip
num_input_frames=self.num_input_frames))

return data_infos
53 changes: 53 additions & 0 deletions mmedit/datasets/sr_test_multiple_gt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import glob
import os.path as osp

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRTestMultipleGTDataset(BaseSRDataset):
"""Test dataset for video super resolution for recurrent networks.
It assumes all video sequences under the root directory is used for test.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
frames. Then it applies specified transforms and finally returns a dict
containing paired data and other information.
Args:
lq_folder (str | :obj:`Path`): Path to a lq folder.
gt_folder (str | :obj:`Path`): Path to a gt folder.
pipeline (list[dict | callable]): A sequence of data transformations.
scale (int): Upsampling scale ratio.
test_mode (bool): Store `True` when building test dataset.
Default: `True`.
"""

def __init__(self, lq_folder, gt_folder, pipeline, scale, test_mode=True):
super().__init__(pipeline, scale, test_mode)

self.lq_folder = str(lq_folder)
self.gt_folder = str(gt_folder)
self.data_infos = self.load_annotations()

def load_annotations(self):
"""Load annoations for the test dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""

sequences = sorted(glob.glob(osp.join(self.lq_folder, '*')))

data_infos = []
for sequence in sequences:
sequence_length = len(glob.glob(osp.join(sequence, '*.png')))
data_infos.append(
dict(
lq_path=self.lq_folder,
gt_path=self.gt_folder,
key=sequence.replace(f'{self.lq_folder}/', ''),
sequence_length=int(sequence_length)))

return data_infos
68 changes: 68 additions & 0 deletions mmedit/datasets/sr_vimeo90k_multiple_gt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import glob
import os.path as osp

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRVimeo90KMultipleGTDataset(BaseSRDataset):
"""Vimeo90K dataset for video super resolution for recurrent networks.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
frames. Then it applies specified transforms and finally returns a dict
containing paired data and other information.
It reads Vimeo90K keys from the txt file. Each line contains:
1. video frame folder
2. image shape
Examples:
::
00001/0266 (256,448,3)
00001/0268 (256,448,3)
Args:
lq_folder (str | :obj:`Path`): Path to a lq folder.
gt_folder (str | :obj:`Path`): Path to a gt folder.
ann_file (str | :obj:`Path`): Path to the annotation file.
pipeline (list[dict | callable]): A sequence of data transformations.
scale (int): Upsampling scale ratio.
test_mode (bool): Store `True` when building test dataset.
Default: `False`.
"""

def __init__(self, lq_folder, gt_folder, ann_file, pipeline, scale,
test_mode):
super().__init__(pipeline, scale, test_mode)

self.lq_folder = str(lq_folder)
self.gt_folder = str(gt_folder)
self.ann_file = str(ann_file)

self.data_infos = self.load_annotations()

def load_annotations(self):
"""Load annoations for Vimeo-90K dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
# get keys
with open(self.ann_file, 'r') as fin:
keys = [line.strip().split(' ')[0] for line in fin]

data_infos = []
for key in keys:
lq_paths = sorted(
glob.glob(osp.join(self.lq_folder, key, '*.png')))
gt_paths = sorted(
glob.glob(osp.join(self.gt_folder, key, '*.png')))

data_infos.append(
dict(lq_path=lq_paths, gt_path=gt_paths, key=key))

return data_infos
Binary file added tests/data/test_multiple_gt/sequence_1/im1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_multiple_gt/sequence_1/im2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_multiple_gt/sequence_2/im1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
145 changes: 144 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
import pytest
from torch.utils.data import Dataset

# yapf: disable
from mmedit.datasets import (AdobeComp1kDataset, BaseGenerationDataset,
BaseSRDataset, GenerationPairedDataset,
GenerationUnpairedDataset, RepeatDataset,
SRAnnotationDataset, SRFolderDataset,
SRFolderGTDataset, SRLmdbDataset, SRREDSDataset,
SRVid4Dataset, SRVimeo90KDataset)
SRREDSMultipleGTDataset, SRTestMultipleGTDataset,
SRVid4Dataset, SRVimeo90KDataset,
SRVimeo90KMultipleGTDataset)

# yapf: enable


def mock_open(*args, **kwargs):
Expand Down Expand Up @@ -829,3 +834,141 @@ def test_vid4_dataset():
pipeline=[],
scale=4,
test_mode=False)


def test_sr_reds_multiple_gt_dataset():
root_path = Path(__file__).parent / 'data'

# official val partition
reds_dataset = SRREDSMultipleGTDataset(
lq_folder=root_path,
gt_folder=root_path,
num_input_frames=15,
pipeline=[],
scale=4,
val_partition='official',
test_mode=False)

assert len(reds_dataset.data_infos) == 240 # 240 training clips
assert reds_dataset.data_infos[0] == dict(
lq_path=str(root_path),
gt_path=str(root_path),
key='000',
sequence_length=100,
num_input_frames=15)

# REDS4 val partition
reds_dataset = SRREDSMultipleGTDataset(
lq_folder=root_path,
gt_folder=root_path,
num_input_frames=20,
pipeline=[],
scale=4,
val_partition='REDS4',
test_mode=False)

assert len(reds_dataset.data_infos) == 266 # 266 training clips
assert reds_dataset.data_infos[0] == dict(
lq_path=str(root_path),
gt_path=str(root_path),
key='001',
sequence_length=100,
num_input_frames=20) # 000 is been removed

with pytest.raises(ValueError):
# wrong val_partitaion
reds_dataset = SRREDSMultipleGTDataset(
lq_folder=root_path,
gt_folder=root_path,
num_input_frames=5,
pipeline=[],
scale=4,
val_partition='wrong_val_partition',
test_mode=False)

# test mode
# official val partition
reds_dataset = SRREDSMultipleGTDataset(
lq_folder=root_path,
gt_folder=root_path,
num_input_frames=5,
pipeline=[],
scale=4,
val_partition='official',
test_mode=True)

assert len(reds_dataset.data_infos) == 30 # 30 test clips
assert reds_dataset.data_infos[0] == dict(
lq_path=str(root_path),
gt_path=str(root_path),
key='240',
sequence_length=100,
num_input_frames=5)

# REDS4 val partition
reds_dataset = SRREDSMultipleGTDataset(
lq_folder=root_path,
gt_folder=root_path,
num_input_frames=5,
pipeline=[],
scale=4,
val_partition='REDS4',
test_mode=True)

assert len(reds_dataset.data_infos) == 4 # 4 test clips
assert reds_dataset.data_infos[1] == dict(
lq_path=str(root_path),
gt_path=str(root_path),
key='011',
sequence_length=100,
num_input_frames=5)


def test_sr_vimeo90k_mutiple_gt_dataset():
root_path = Path(__file__).parent / 'data/vimeo90k'

txt_content = ('00001/0266 (256,448,3)\n')
mocked_open_function = mock_open(read_data=txt_content)
lq_paths = [
str(root_path / '00001' / '0266' / f'im{v}.png') for v in range(1, 8)
]
gt_paths = [
str(root_path / '00001' / '0266' / f'im{v}.png') for v in range(1, 8)
]

with patch('builtins.open', mocked_open_function):
vimeo90k_dataset = SRVimeo90KMultipleGTDataset(
lq_folder=root_path,
gt_folder=root_path,
ann_file='fake_ann_file',
pipeline=[],
scale=4,
test_mode=False)

assert vimeo90k_dataset.data_infos == [
dict(lq_path=lq_paths, gt_path=gt_paths, key='00001/0266')
]


def test_sr_test_multiple_gt_dataset():
root_path = Path(__file__).parent / 'data/test_multiple_gt'

test_dataset = SRTestMultipleGTDataset(
lq_folder=root_path,
gt_folder=root_path,
pipeline=[],
scale=4,
test_mode=True)

assert test_dataset.data_infos == [
dict(
lq_path=str(root_path),
gt_path=str(root_path),
key='sequence_1',
sequence_length=2),
dict(
lq_path=str(root_path),
gt_path=str(root_path),
key='sequence_2',
sequence_length=1)
]

0 comments on commit 9243ffc

Please sign in to comment.