-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #238 from ckkelvinchan/multiple-gt-dataset
Add MultipleGT datasets
- Loading branch information
Showing
15 changed files
with
344 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from .base_sr_dataset import BaseSRDataset | ||
from .registry import DATASETS | ||
|
||
|
||
@DATASETS.register_module() | ||
class SRREDSMultipleGTDataset(BaseSRDataset): | ||
"""REDS dataset for video super resolution for recurrent networks. | ||
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) | ||
frames. Then it applies specified transforms and finally returns a dict | ||
containing paired data and other information. | ||
Args: | ||
lq_folder (str | :obj:`Path`): Path to a lq folder. | ||
gt_folder (str | :obj:`Path`): Path to a gt folder. | ||
num_input_frames (int): Number of input frames. | ||
pipeline (list[dict | callable]): A sequence of data transformations. | ||
scale (int): Upsampling scale ratio. | ||
val_partition (str): Validation partition mode. Choices ['official' or | ||
'REDS4']. Default: 'official'. | ||
test_mode (bool): Store `True` when building test dataset. | ||
Default: `False`. | ||
""" | ||
|
||
def __init__(self, | ||
lq_folder, | ||
gt_folder, | ||
num_input_frames, | ||
pipeline, | ||
scale, | ||
val_partition='official', | ||
test_mode=False): | ||
super().__init__(pipeline, scale, test_mode) | ||
self.lq_folder = str(lq_folder) | ||
self.gt_folder = str(gt_folder) | ||
self.num_input_frames = num_input_frames | ||
self.val_partition = val_partition | ||
self.data_infos = self.load_annotations() | ||
|
||
def load_annotations(self): | ||
"""Load annoations for REDS dataset. | ||
Returns: | ||
dict: Returned dict for LQ and GT pairs. | ||
""" | ||
# generate keys | ||
keys = [f'{i:03d}' for i in range(0, 270)] | ||
|
||
if self.val_partition == 'REDS4': | ||
val_partition = ['000', '011', '015', '020'] | ||
elif self.val_partition == 'official': | ||
val_partition = [f'{i:03d}' for i in range(240, 270)] | ||
else: | ||
raise ValueError( | ||
f'Wrong validation partition {self.val_partition}.' | ||
f'Supported ones are ["official", "REDS4"]') | ||
|
||
if self.test_mode: | ||
keys = [v for v in keys if v in val_partition] | ||
else: | ||
keys = [v for v in keys if v not in val_partition] | ||
|
||
data_infos = [] | ||
for key in keys: | ||
data_infos.append( | ||
dict( | ||
lq_path=self.lq_folder, | ||
gt_path=self.gt_folder, | ||
key=key, | ||
sequence_length=100, # REDS has 100 frames for each clip | ||
num_input_frames=self.num_input_frames)) | ||
|
||
return data_infos |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import glob | ||
import os.path as osp | ||
|
||
from .base_sr_dataset import BaseSRDataset | ||
from .registry import DATASETS | ||
|
||
|
||
@DATASETS.register_module() | ||
class SRTestMultipleGTDataset(BaseSRDataset): | ||
"""Test dataset for video super resolution for recurrent networks. | ||
It assumes all video sequences under the root directory is used for test. | ||
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) | ||
frames. Then it applies specified transforms and finally returns a dict | ||
containing paired data and other information. | ||
Args: | ||
lq_folder (str | :obj:`Path`): Path to a lq folder. | ||
gt_folder (str | :obj:`Path`): Path to a gt folder. | ||
pipeline (list[dict | callable]): A sequence of data transformations. | ||
scale (int): Upsampling scale ratio. | ||
test_mode (bool): Store `True` when building test dataset. | ||
Default: `True`. | ||
""" | ||
|
||
def __init__(self, lq_folder, gt_folder, pipeline, scale, test_mode=True): | ||
super().__init__(pipeline, scale, test_mode) | ||
|
||
self.lq_folder = str(lq_folder) | ||
self.gt_folder = str(gt_folder) | ||
self.data_infos = self.load_annotations() | ||
|
||
def load_annotations(self): | ||
"""Load annoations for the test dataset. | ||
Returns: | ||
dict: Returned dict for LQ and GT pairs. | ||
""" | ||
|
||
sequences = sorted(glob.glob(osp.join(self.lq_folder, '*'))) | ||
|
||
data_infos = [] | ||
for sequence in sequences: | ||
sequence_length = len(glob.glob(osp.join(sequence, '*.png'))) | ||
data_infos.append( | ||
dict( | ||
lq_path=self.lq_folder, | ||
gt_path=self.gt_folder, | ||
key=sequence.replace(f'{self.lq_folder}/', ''), | ||
sequence_length=int(sequence_length))) | ||
|
||
return data_infos |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import glob | ||
import os.path as osp | ||
|
||
from .base_sr_dataset import BaseSRDataset | ||
from .registry import DATASETS | ||
|
||
|
||
@DATASETS.register_module() | ||
class SRVimeo90KMultipleGTDataset(BaseSRDataset): | ||
"""Vimeo90K dataset for video super resolution for recurrent networks. | ||
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) | ||
frames. Then it applies specified transforms and finally returns a dict | ||
containing paired data and other information. | ||
It reads Vimeo90K keys from the txt file. Each line contains: | ||
1. video frame folder | ||
2. image shape | ||
Examples: | ||
:: | ||
00001/0266 (256,448,3) | ||
00001/0268 (256,448,3) | ||
Args: | ||
lq_folder (str | :obj:`Path`): Path to a lq folder. | ||
gt_folder (str | :obj:`Path`): Path to a gt folder. | ||
ann_file (str | :obj:`Path`): Path to the annotation file. | ||
pipeline (list[dict | callable]): A sequence of data transformations. | ||
scale (int): Upsampling scale ratio. | ||
test_mode (bool): Store `True` when building test dataset. | ||
Default: `False`. | ||
""" | ||
|
||
def __init__(self, lq_folder, gt_folder, ann_file, pipeline, scale, | ||
test_mode): | ||
super().__init__(pipeline, scale, test_mode) | ||
|
||
self.lq_folder = str(lq_folder) | ||
self.gt_folder = str(gt_folder) | ||
self.ann_file = str(ann_file) | ||
|
||
self.data_infos = self.load_annotations() | ||
|
||
def load_annotations(self): | ||
"""Load annoations for Vimeo-90K dataset. | ||
Returns: | ||
dict: Returned dict for LQ and GT pairs. | ||
""" | ||
# get keys | ||
with open(self.ann_file, 'r') as fin: | ||
keys = [line.strip().split(' ')[0] for line in fin] | ||
|
||
data_infos = [] | ||
for key in keys: | ||
lq_paths = sorted( | ||
glob.glob(osp.join(self.lq_folder, key, '*.png'))) | ||
gt_paths = sorted( | ||
glob.glob(osp.join(self.gt_folder, key, '*.png'))) | ||
|
||
data_infos.append( | ||
dict(lq_path=lq_paths, gt_path=gt_paths, key=key)) | ||
|
||
return data_infos |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters