Skip to content

Commit d9e6ad1

Browse files
author
juingzhou
authored
Initial commit
1 parent 8806145 commit d9e6ad1

8 files changed

+489
-0
lines changed

README.md

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# ESPCN
2+
3+
This repository is implementation of the ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"](https://arxiv.org/abs/1609.05158).
4+
5+
<center><img src="./thumbnails/fig1.png"></center>
6+
7+
## Requirements
8+
9+
- PyTorch 1.1.0
10+
- Numpy 1.15.4
11+
- Pillow 6.0.0
12+
- h5py 2.8.0
13+
- tqdm 4.30.0
14+
15+
## Train
16+
17+
The 91-image, Set5 dataset converted to HDF5 can be downloaded from the links below.
18+
19+
| Dataset | Scale | Type | Link |
20+
|---------|-------|------|------|
21+
| 91-image | 3 | Train | [Link](/BLAH_BLAH/) |
22+
| Set5 | 3 | Eval | [Link](/BLAH_BLAH/) |
23+
24+
Otherwise, you can use `prepare.py` to create custom dataset.
25+
26+
```bash
27+
bash run.sh
28+
```
29+
30+
## Test
31+
32+
Pre-trained weights can be downloaded from the links below.
33+
34+
| Model | Scale | Link |
35+
|-------|-------|------|
36+
| ESPCN (91) | 3 | [Link](/BLAH_BLAH/outputs/x3/best.pth) |
37+
38+
The results are stored in the same path as the query image.
39+
40+
```bash
41+
bash run.sh
42+
```
43+
44+
## Results
45+
46+
PSNR was calculated on the Y channel.
47+
48+
### Set5
49+
50+
| Eval. Mat | Scale | Paper (91) | Ours (91) |
51+
|-----------|-------|-------|-----------------|
52+
| PSNR | 3 | 32.55 | 32.88 |
53+
54+
<table>
55+
<tr>
56+
<td><center>Original</center></td>
57+
<td><center>BICUBIC x3</center></td>
58+
<td><center>ESPCN x3 (23.84 dB)</center></td>
59+
</tr>
60+
<tr>
61+
<td>
62+
<center><img src="./data/baboon.bmp""></center>
63+
</td>
64+
<td>
65+
<center><img src="./data/baboon_bicubic_x3.bmp"></center>
66+
</td>
67+
<td>
68+
<center><img src="./data/baboon_espcn_x3.bmp"></center>
69+
</td>
70+
</tr>
71+
<tr>
72+
<td><center>Original</center></td>
73+
<td><center>BICUBIC x3</center></td>
74+
<td><center>ESPCN x3 (25.32 dB)</center></td>
75+
</tr>
76+
<tr>
77+
<td>
78+
<center><img src="./data/comic.bmp""></center>
79+
</td>
80+
<td>
81+
<center><img src="./data/comic_bicubic_x3.bmp"></center>
82+
</td>
83+
<td>
84+
<center><img src="./data/comic_espcn_x3.bmp"></center>
85+
</td>
86+
</tr>
87+
</table>

datasets.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import h5py
2+
import numpy as np
3+
from torch.utils.data import Dataset
4+
5+
6+
class TrainDataset(Dataset):
7+
def __init__(self, h5_file):
8+
super(TrainDataset, self).__init__()
9+
self.h5_file = h5_file
10+
11+
def __getitem__(self, idx):
12+
with h5py.File(self.h5_file, 'r') as f:
13+
return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)
14+
15+
def __len__(self):
16+
with h5py.File(self.h5_file, 'r') as f:
17+
return len(f['lr'])
18+
19+
20+
class EvalDataset(Dataset):
21+
def __init__(self, h5_file):
22+
super(EvalDataset, self).__init__()
23+
self.h5_file = h5_file
24+
25+
def __getitem__(self, idx):
26+
with h5py.File(self.h5_file, 'r') as f:
27+
return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)
28+
29+
def __len__(self):
30+
with h5py.File(self.h5_file, 'r') as f:
31+
return len(f['lr'])

models.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import math
2+
from torch import nn
3+
4+
5+
class ESPCN(nn.Module):
6+
def __init__(self, scale_factor, num_channels=1):
7+
super(ESPCN, self).__init__()
8+
self.first_part = nn.Sequential(
9+
nn.Conv2d(num_channels, 64, kernel_size=5, padding=5//2),
10+
nn.Tanh(),
11+
nn.Conv2d(64, 32, kernel_size=3, padding=3//2),
12+
nn.Tanh(),
13+
)
14+
self.last_part = nn.Sequential(
15+
nn.Conv2d(32, num_channels * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
16+
nn.PixelShuffle(scale_factor)
17+
)
18+
19+
self._initialize_weights()
20+
21+
def _initialize_weights(self):
22+
for m in self.modules():
23+
if isinstance(m, nn.Conv2d):
24+
if m.in_channels == 32:
25+
nn.init.normal_(m.weight.data, mean=0.0, std=0.001)
26+
nn.init.zeros_(m.bias.data)
27+
else:
28+
nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
29+
nn.init.zeros_(m.bias.data)
30+
31+
def forward(self, x):
32+
x = self.first_part(x)
33+
x = self.last_part(x)
34+
return x
35+
36+

prepare.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import argparse
2+
import glob
3+
import h5py
4+
import numpy as np
5+
import PIL.Image as pil_image
6+
from utils import convert_rgb_to_y
7+
8+
9+
def train(args):
10+
h5_file = h5py.File(args.output_path, 'w')
11+
12+
lr_patches = []
13+
hr_patches = []
14+
15+
for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
16+
hr = pil_image.open(image_path).convert('RGB')
17+
hr_width = (hr.width // args.scale) * args.scale
18+
hr_height = (hr.height // args.scale) * args.scale
19+
hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
20+
lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
21+
hr = np.array(hr).astype(np.float32)
22+
lr = np.array(lr).astype(np.float32)
23+
hr = convert_rgb_to_y(hr)
24+
lr = convert_rgb_to_y(lr)
25+
26+
for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
27+
for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
28+
lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
29+
hr_patches.append(hr[i * args.scale:i * args.scale + args.patch_size * args.scale, j * args.scale:j * args.scale + args.patch_size * args.scale])
30+
31+
lr_patches = np.array(lr_patches)
32+
hr_patches = np.array(hr_patches)
33+
34+
h5_file.create_dataset('lr', data=lr_patches)
35+
h5_file.create_dataset('hr', data=hr_patches)
36+
37+
h5_file.close()
38+
39+
40+
def eval(args):
41+
h5_file = h5py.File(args.output_path, 'w')
42+
43+
lr_group = h5_file.create_group('lr')
44+
hr_group = h5_file.create_group('hr')
45+
46+
for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
47+
hr = pil_image.open(image_path).convert('RGB')
48+
hr_width = (hr.width // args.scale) * args.scale
49+
hr_height = (hr.height // args.scale) * args.scale
50+
hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
51+
lr = hr.resize((hr.width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
52+
hr = np.array(hr).astype(np.float32)
53+
lr = np.array(lr).astype(np.float32)
54+
hr = convert_rgb_to_y(hr)
55+
lr = convert_rgb_to_y(lr)
56+
57+
lr_group.create_dataset(str(i), data=lr)
58+
hr_group.create_dataset(str(i), data=hr)
59+
60+
h5_file.close()
61+
62+
63+
if __name__ == '__main__':
64+
parser = argparse.ArgumentParser()
65+
parser.add_argument('--images-dir', type=str, required=True)
66+
parser.add_argument('--output-path', type=str, required=True)
67+
parser.add_argument('--scale', type=int, default=3)
68+
parser.add_argument('--patch-size', type=int, default=17)
69+
parser.add_argument('--stride', type=int, default=13)
70+
parser.add_argument('--eval', action='store_true')
71+
args = parser.parse_args()
72+
73+
if not args.eval:
74+
train(args)
75+
else:
76+
eval(args)

run.sh

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Train step:
2+
# python train.py --train-file "BLAH_BLAH/91-image_x3.h5" --eval-file "BLAH_BLAH/Set5_x3.h5" --outputs-dir "BLAH_BLAH/outputs" --scale 3 --lr 1e-3 --batch-size 32 --num-epochs 200 --num-workers 8 --seed 123
3+
# Test step:
4+
python test.py --weights-file "BLAH_BLAH/outputs/x3/best.pth" --image-file "data/baboon.bmp" --scale 3

test.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import argparse
2+
3+
import torch
4+
import torch.backends.cudnn as cudnn
5+
import numpy as np
6+
import PIL.Image as pil_image
7+
8+
from models import ESPCN
9+
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnr
10+
11+
12+
if __name__ == '__main__':
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument('--weights-file', type=str, required=True)
15+
parser.add_argument('--image-file', type=str, required=True)
16+
parser.add_argument('--scale', type=int, default=3)
17+
args = parser.parse_args()
18+
19+
cudnn.benchmark = True
20+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
21+
22+
model = ESPCN(scale_factor=args.scale).to(device)
23+
24+
state_dict = model.state_dict()
25+
for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
26+
if n in state_dict.keys():
27+
state_dict[n].copy_(p)
28+
else:
29+
raise KeyError(n)
30+
31+
model.eval()
32+
33+
image = pil_image.open(args.image_file).convert('RGB')
34+
35+
image_width = (image.width // args.scale) * args.scale
36+
image_height = (image.height // args.scale) * args.scale
37+
38+
hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
39+
lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
40+
bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
41+
bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
42+
43+
lr, _ = preprocess(lr, device)
44+
hr, _ = preprocess(hr, device)
45+
_, ycbcr = preprocess(bicubic, device)
46+
47+
with torch.no_grad():
48+
preds = model(lr).clamp(0.0, 1.0)
49+
50+
psnr = calc_psnr(hr, preds)
51+
print('PSNR: {:.2f}'.format(psnr))
52+
53+
preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
54+
55+
output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
56+
output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
57+
output = pil_image.fromarray(output)
58+
output.save(args.image_file.replace('.', '_espcn_x{}.'.format(args.scale)))

0 commit comments

Comments
 (0)