Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultipleGT datasets #238

Merged
merged 1 commit into from
Apr 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
from .sr_folder_gt_dataset import SRFolderGTDataset
from .sr_lmdb_dataset import SRLmdbDataset
from .sr_reds_dataset import SRREDSDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
from .sr_test_multiple_gt_dataset import SRTestMultipleGTDataset
from .sr_vid4_dataset import SRVid4Dataset
from .sr_vimeo90k_dataset import SRVimeo90KDataset
from .sr_vimeo90k_multiple_gt_dataset import SRVimeo90KMultipleGTDataset

# Names exported by this package. Each entry must be a separate string
# literal followed by a comma: adjacent literals without a comma are silently
# concatenated into one bogus name.
__all__ = [
    'DATASETS', 'PIPELINES', 'build_dataset', 'build_dataloader',
    'BaseDataset', 'BaseMattingDataset', 'ImgInpaintingDataset',
    'AdobeComp1kDataset', 'SRLmdbDataset', 'SRFolderDataset',
    'SRAnnotationDataset', 'BaseSRDataset', 'RepeatDataset', 'SRREDSDataset',
    'SRVimeo90KDataset', 'BaseGenerationDataset', 'GenerationPairedDataset',
    'GenerationUnpairedDataset', 'SRVid4Dataset', 'SRFolderGTDataset',
    'SRREDSMultipleGTDataset', 'SRVimeo90KMultipleGTDataset',
    'SRTestMultipleGTDataset'
]
73 changes: 73 additions & 0 deletions mmedit/datasets/sr_reds_multiple_gt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRREDSMultipleGTDataset(BaseSRDataset):
    """REDS dataset for video super resolution for recurrent networks.

    The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
    frames. Then it applies specified transforms and finally returns a dict
    containing paired data and other information.

    Args:
        lq_folder (str | :obj:`Path`): Path to a lq folder.
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        num_input_frames (int): Number of input frames.
        pipeline (list[dict | callable]): A sequence of data transformations.
        scale (int): Upsampling scale ratio.
        val_partition (str): Validation partition mode. Choices ['official' or
            'REDS4']. Default: 'official'.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
    """

    def __init__(self,
                 lq_folder,
                 gt_folder,
                 num_input_frames,
                 pipeline,
                 scale,
                 val_partition='official',
                 test_mode=False):
        super().__init__(pipeline, scale, test_mode)
        # Folders are stored as plain strings so ``Path`` inputs are accepted.
        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.num_input_frames = num_input_frames
        self.val_partition = val_partition
        # Must come last: load_annotations() reads the attributes set above.
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for REDS dataset.

        Clip keys are the zero-padded strings '000' to '269'. In test mode
        only the clips of the validation partition are kept; otherwise only
        the remaining clips are kept.

        Returns:
            list[dict]: One dict per selected clip with the keys ``lq_path``,
                ``gt_path``, ``key``, ``sequence_length`` and
                ``num_input_frames``.

        Raises:
            ValueError: If ``val_partition`` is neither 'official' nor
                'REDS4'.
        """
        if self.val_partition == 'REDS4':
            val_partition = {'000', '011', '015', '020'}
        elif self.val_partition == 'official':
            val_partition = {f'{i:03d}' for i in range(240, 270)}
        else:
            raise ValueError(
                f'Wrong validation partition {self.val_partition}. '
                'Supported ones are ["official", "REDS4"]')

        # Iterate the clip ids in ascending order so the result stays sorted;
        # the set above is only used for membership tests.
        all_keys = (f'{i:03d}' for i in range(0, 270))
        if self.test_mode:
            keys = [v for v in all_keys if v in val_partition]
        else:
            keys = [v for v in all_keys if v not in val_partition]

        return [
            dict(
                lq_path=self.lq_folder,
                gt_path=self.gt_folder,
                key=key,
                sequence_length=100,  # REDS has 100 frames for each clip
                num_input_frames=self.num_input_frames) for key in keys
        ]
53 changes: 53 additions & 0 deletions mmedit/datasets/sr_test_multiple_gt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import glob
import os.path as osp

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRTestMultipleGTDataset(BaseSRDataset):
    """Test dataset for video super resolution for recurrent networks.

    It assumes all video sequences under the root directory are used for
    test.

    The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
    frames. Then it applies specified transforms and finally returns a dict
    containing paired data and other information.

    Args:
        lq_folder (str | :obj:`Path`): Path to a lq folder.
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        pipeline (list[dict | callable]): A sequence of data transformations.
        scale (int): Upsampling scale ratio.
        test_mode (bool): Store `True` when building test dataset.
            Default: `True`.
    """

    def __init__(self, lq_folder, gt_folder, pipeline, scale, test_mode=True):
        super().__init__(pipeline, scale, test_mode)

        # Folders are stored as plain strings so ``Path`` inputs are accepted.
        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for the test dataset.

        Every entry directly under ``lq_folder`` is treated as one sequence,
        and its length is the number of ``*.png`` files found inside it.

        Returns:
            list[dict]: One dict per sequence, sorted by path, with the keys
                ``lq_path``, ``gt_path``, ``key`` and ``sequence_length``.
        """
        sequences = sorted(glob.glob(osp.join(self.lq_folder, '*')))

        data_infos = []
        for sequence in sequences:
            sequence_length = len(glob.glob(osp.join(sequence, '*.png')))
            data_infos.append(
                dict(
                    lq_path=self.lq_folder,
                    gt_path=self.gt_folder,
                    # basename() instead of stripping a '<lq_folder>/' prefix:
                    # the prefix match breaks with Windows separators or when
                    # lq_folder ends with a slash.
                    key=osp.basename(sequence),
                    sequence_length=sequence_length))

        return data_infos
68 changes: 68 additions & 0 deletions mmedit/datasets/sr_vimeo90k_multiple_gt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import glob
import os.path as osp

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRVimeo90KMultipleGTDataset(BaseSRDataset):
    """Vimeo90K dataset for video super resolution for recurrent networks.

    The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
    frames. Then it applies specified transforms and finally returns a dict
    containing paired data and other information.

    It reads Vimeo90K keys from the txt file. Each line contains:

        1. video frame folder
        2. image shape

    Examples:

    ::

        00001/0266 (256,448,3)
        00001/0268 (256,448,3)

    Args:
        lq_folder (str | :obj:`Path`): Path to a lq folder.
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        ann_file (str | :obj:`Path`): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transformations.
        scale (int): Upsampling scale ratio.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
    """

    def __init__(self,
                 lq_folder,
                 gt_folder,
                 ann_file,
                 pipeline,
                 scale,
                 test_mode=False):
        super().__init__(pipeline, scale, test_mode)

        # Paths are stored as plain strings so ``Path`` inputs are accepted.
        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.ann_file = str(ann_file)

        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for Vimeo-90K dataset.

        The key of a clip is the first space-separated field of each
        non-blank line in ``ann_file``. The frame paths of a clip are the
        sorted ``*.png`` files under ``<folder>/<key>``.

        Returns:
            list[dict]: One dict per key with ``lq_path`` and ``gt_path``
                (sorted lists of png paths) and ``key``.
        """
        # Blank lines are skipped: an empty key would otherwise make the glob
        # below collect png files directly under the dataset root.
        with open(self.ann_file, 'r') as fin:
            keys = [
                stripped.split(' ')[0]
                for stripped in (line.strip() for line in fin) if stripped
            ]

        data_infos = []
        for key in keys:
            lq_paths = sorted(
                glob.glob(osp.join(self.lq_folder, key, '*.png')))
            gt_paths = sorted(
                glob.glob(osp.join(self.gt_folder, key, '*.png')))

            data_infos.append(
                dict(lq_path=lq_paths, gt_path=gt_paths, key=key))

        return data_infos
Binary file added tests/data/test_multiple_gt/sequence_1/im1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_multiple_gt/sequence_1/im2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_multiple_gt/sequence_2/im1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/vimeo90k/00001/0266/im7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
145 changes: 144 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
import pytest
from torch.utils.data import Dataset

# yapf: disable
from mmedit.datasets import (AdobeComp1kDataset, BaseGenerationDataset,
BaseSRDataset, GenerationPairedDataset,
GenerationUnpairedDataset, RepeatDataset,
SRAnnotationDataset, SRFolderDataset,
SRFolderGTDataset, SRLmdbDataset, SRREDSDataset,
SRVid4Dataset, SRVimeo90KDataset)
SRREDSMultipleGTDataset, SRTestMultipleGTDataset,
SRVid4Dataset, SRVimeo90KDataset,
SRVimeo90KMultipleGTDataset)

# yapf: enable


def mock_open(*args, **kwargs):
Expand Down Expand Up @@ -829,3 +834,141 @@ def test_vid4_dataset():
pipeline=[],
scale=4,
test_mode=False)


def test_sr_reds_multiple_gt_dataset():
    """Check the clip selection of SRREDSMultipleGTDataset for every mode."""
    data_root = Path(__file__).parent / 'data'

    def make_dataset(num_input_frames, val_partition, test_mode):
        # Only the three varied arguments differ between the cases below.
        return SRREDSMultipleGTDataset(
            lq_folder=data_root,
            gt_folder=data_root,
            num_input_frames=num_input_frames,
            pipeline=[],
            scale=4,
            val_partition=val_partition,
            test_mode=test_mode)

    def expected_info(key, num_input_frames):
        return dict(
            lq_path=str(data_root),
            gt_path=str(data_root),
            key=key,
            sequence_length=100,
            num_input_frames=num_input_frames)

    # training mode, official val partition: 240 training clips
    infos = make_dataset(15, 'official', False).data_infos
    assert len(infos) == 240
    assert infos[0] == expected_info('000', 15)

    # training mode, REDS4 val partition: 266 training clips
    infos = make_dataset(20, 'REDS4', False).data_infos
    assert len(infos) == 266
    # clip 000 belongs to REDS4, so the first training clip is 001
    assert infos[0] == expected_info('001', 20)

    # an unknown val partition is rejected
    with pytest.raises(ValueError):
        make_dataset(5, 'wrong_val_partition', False)

    # test mode, official val partition: 30 test clips
    infos = make_dataset(5, 'official', True).data_infos
    assert len(infos) == 30
    assert infos[0] == expected_info('240', 5)

    # test mode, REDS4 val partition: 4 test clips
    infos = make_dataset(5, 'REDS4', True).data_infos
    assert len(infos) == 4
    assert infos[1] == expected_info('011', 5)


def test_sr_vimeo90k_mutiple_gt_dataset():
    """Check that one annotation line yields one entry with 7 frame paths."""
    data_root = Path(__file__).parent / 'data/vimeo90k'
    clip_dir = data_root / '00001' / '0266'

    # LQ and GT share the same folder here, so both path lists are identical.
    frame_paths = [str(clip_dir / f'im{idx}.png') for idx in range(1, 8)]
    expected = [
        dict(
            lq_path=list(frame_paths),
            gt_path=list(frame_paths),
            key='00001/0266')
    ]

    # The annotation file is faked by patching the builtin ``open``.
    fake_open = mock_open(read_data='00001/0266 (256,448,3)\n')
    with patch('builtins.open', fake_open):
        dataset = SRVimeo90KMultipleGTDataset(
            lq_folder=data_root,
            gt_folder=data_root,
            ann_file='fake_ann_file',
            pipeline=[],
            scale=4,
            test_mode=False)

    assert dataset.data_infos == expected


def test_sr_test_multiple_gt_dataset():
    """Check that each sub-folder becomes one sequence with its png count."""
    data_root = Path(__file__).parent / 'data/test_multiple_gt'

    dataset = SRTestMultipleGTDataset(
        lq_folder=data_root,
        gt_folder=data_root,
        pipeline=[],
        scale=4,
        test_mode=True)

    # (sequence name, number of png frames) in the expected sorted order
    sequences = (('sequence_1', 2), ('sequence_2', 1))
    expected = [
        dict(
            lq_path=str(data_root),
            gt_path=str(data_root),
            key=name,
            sequence_length=length) for name, length in sequences
    ]

    assert dataset.data_infos == expected