Skip to content

Commit

Permalink
Implement basicvsr (PaddlePaddle#356)
Browse files Browse the repository at this point in the history
* add basicvsr model
  • Loading branch information
LielinJiang authored Jul 6, 2021
1 parent 058faa8 commit 4bb2686
Show file tree
Hide file tree
Showing 11 changed files with 1,140 additions and 9 deletions.
92 changes: 92 additions & 0 deletions configs/basicvsr_reds.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
(0., 1.)

model:
name: BasicVSRModel
fix_iter: 5000
generator:
name: BasicVSRNet
mid_channels: 64
num_blocks: 30
pixel_criterion:
name: CharbonnierLoss
reduction: mean

dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4 # 6
batch_size: 2 # 4*2
dataset:
name: SRREDSMultipleGTDataset
mode: train
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 15
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4

test:
name: SRREDSMultipleGTDataset
mode: test
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1

lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [300000]
restart_weights: [1]
eta_min: !!float 1e-7

optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99

validate:
# FIXME: avoid oom
interval: 5000000
save_img: false

metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False

log_config:
interval: 100
visiual_interval: 500

snapshot_config:
interval: 5000
1 change: 1 addition & 0 deletions ppgan/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset
from .lapstyle_dataset import LapStyleDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
from .mpr_dataset import MPRTrain, MPRVal, MPRTest
20 changes: 16 additions & 4 deletions ppgan/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,33 @@

from paddle.distributed import ParallelEnv
from paddle.io import DistributedBatchSampler
from ..utils.registry import Registry

from .repeat_dataset import RepeatDataset
from ..utils.registry import Registry, build_from_config

DATASETS = Registry("DATASETS")


def build_dataset(cfg):
name = cfg.pop('name')

if name == 'RepeatDataset':
dataset_ = build_from_config(cfg['dataset'], DATASETS)
dataset = RepeatDataset(dataset_, cfg['times'])
else:
dataset = dataset = DATASETS.get(name)(**cfg)

return dataset


def build_dataloader(cfg, is_train=True, distributed=True):
cfg_ = cfg.copy()

batch_size = cfg_.pop('batch_size', 1)
num_workers = cfg_.pop('num_workers', 0)
use_shared_memory = cfg_.pop('use_shared_memory', True)

name = cfg_.pop('name')

dataset = DATASETS.get(name)(**cfg_)
dataset = build_dataset(cfg_)

if distributed:
sampler = DistributedBatchSampler(dataset,
Expand Down
51 changes: 51 additions & 0 deletions ppgan/datasets/repeat_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


class RepeatDataset(paddle.io.Dataset):
"""A wrapper of repeated dataset.
The length of repeated dataset will be `times` larger than the original
dataset. This is useful when the data loading time is long but the dataset
is small. Using RepeatDataset can reduce the data loading time between
epochs.
Args:
dataset (:obj:`Dataset`): The dataset to be repeated.
times (int): Repeat times.
"""

def __init__(self, dataset, times):
self.dataset = dataset
self.times = times

self._ori_len = len(self.dataset)

def __getitem__(self, idx):
"""Get item at each call.
Args:
idx (int): Index for getting each item.
"""
return self.dataset[idx % self._ori_len]

def __len__(self):
"""Length of the dataset.
Returns:
int: Length of the dataset.
"""
return self.times * self._ori_len
Loading

0 comments on commit 4bb2686

Please sign in to comment.