diff --git a/configs/restorers/basicvsr_plusplus/basicvsr_plusplus_c64n7_8x1_600k_reds4.py b/configs/restorers/basicvsr_plusplus/basicvsr_plusplus_c64n7_8x1_600k_reds4.py index 72ee695dfd..466030621b 100644 --- a/configs/restorers/basicvsr_plusplus/basicvsr_plusplus_c64n7_8x1_600k_reds4.py +++ b/configs/restorers/basicvsr_plusplus/basicvsr_plusplus_c64n7_8x1_600k_reds4.py @@ -102,6 +102,7 @@ pipeline=test_pipeline, scale=4, val_partition='REDS4', + repeat=2, test_mode=True), # test test=dict( diff --git a/mmedit/datasets/sr_reds_multiple_gt_dataset.py b/mmedit/datasets/sr_reds_multiple_gt_dataset.py index 770bfeb03c..040752f958 100644 --- a/mmedit/datasets/sr_reds_multiple_gt_dataset.py +++ b/mmedit/datasets/sr_reds_multiple_gt_dataset.py @@ -19,6 +19,9 @@ class SRREDSMultipleGTDataset(BaseSRDataset): scale (int): Upsampling scale ratio. val_partition (str): Validation partition mode. Choices ['official' or 'REDS4']. Default: 'official'. + repeat (int): Number of replication of the validation set. This is used + to allow training REDS4 with more than 4 GPUs. For example, if + 8 GPUs are used, this number can be set to 2. Default: 1. test_mode (bool): Store `True` when building test dataset. Default: `False`. """ @@ -30,7 +33,14 @@ def __init__(self, pipeline, scale, val_partition='official', + repeat=1, test_mode=False): + + self.repeat = repeat + if not isinstance(repeat, int): + raise TypeError('"repeat" must be an integer, but got ' + f'{type(repeat)}.') + super().__init__(pipeline, scale, test_mode) self.lq_folder = str(lq_folder) self.gt_folder = str(gt_folder) @@ -58,6 +68,7 @@ def load_annotations(self): if self.test_mode: keys = [v for v in keys if v in val_partition] + keys *= self.repeat else: keys = [v for v in keys if v not in val_partition] diff --git a/tests/test_data/test_datasets/test_sr_dataset.py b/tests/test_data/test_datasets/test_sr_dataset.py index e442ddd2b5..e3cb36b71b 100644 --- a/tests/test_data/test_datasets/test_sr_dataset.py +++ b/tests/test_data/test_datasets/test_sr_dataset.py @@ -788,6 +788,37 @@ def test_sr_reds_multiple_gt_dataset(): sequence_length=100, num_input_frames=5) + # REDS4 val partition (repeat > 1) + reds_dataset = SRREDSMultipleGTDataset( + lq_folder=root_path, + gt_folder=root_path, + num_input_frames=5, + pipeline=[], + scale=4, + val_partition='REDS4', + repeat=2, + test_mode=True) + + assert len(reds_dataset.data_infos) == 8 # 4 test clips + assert reds_dataset.data_infos[5] == dict( + lq_path=str(root_path), + gt_path=str(root_path), + key='011', + sequence_length=100, + num_input_frames=5) + + # REDS4 val partition (repeat != int) + with pytest.raises(TypeError): + SRREDSMultipleGTDataset( + lq_folder=root_path, + gt_folder=root_path, + num_input_frames=5, + pipeline=[], + scale=4, + val_partition='REDS4', + repeat=1.5, + test_mode=True) + def test_sr_vimeo90k_mutiple_gt_dataset(): root_path = Path(__file__).parent.parent.parent / 'data/vimeo90k'