forked from KupynOrest/RestoreGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
101 lines (83 loc) · 3.2 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from glob import glob
from typing import Optional
import cv2
import numpy as np
import torch
import yaml
from fire import Fire
from tqdm import tqdm
import albumentations as albu
from aug import get_normalize
from models.networks import get_generator
class Predictor:
    """Run a trained generator on single images.

    The generator architecture is selected from ``config/config.yaml`` (or the
    ``model_name`` override), its weights are loaded from ``weights_path`` and
    inference runs on CUDA.
    """

    def __init__(self, weights_path: str, model_name: str = ''):
        """Build the generator and load its weights.

        Args:
            weights_path: Checkpoint file whose ``'model'`` entry holds the
                generator ``state_dict``.
            model_name: Name passed to ``get_generator``; when empty,
                ``config['model']`` from ``config/config.yaml`` is used.
        """
        with open('config/config.yaml') as cfg:
            # safe_load: bare ``yaml.load(cfg)`` (no Loader) is a TypeError on
            # PyYAML >= 6 and can construct arbitrary objects on older versions.
            config = yaml.safe_load(cfg)
        model = get_generator(model_name or config['model'])
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        model.load_state_dict(torch.load(weights_path)['model'])
        self.model = model.cuda()
        # GAN inference should be in train mode to use actual stats in norm
        # layers; this is deliberate, not a bug.
        self.model.train(True)
        self.normalize_fn = get_normalize()

    @staticmethod
    def _array_to_batch(x: np.ndarray) -> torch.Tensor:
        """Convert an ``(H, W, C)`` array into a ``(1, C, H, W)`` tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    @staticmethod
    def _round_up(value: int, multiple: int) -> int:
        """Return the smallest multiple of ``multiple`` that is >= ``value``."""
        return -(-value // multiple) * multiple

    def _preprocess(self, x: np.ndarray, mask: Optional[np.ndarray]):
        """Resize, normalize and pad an image and its mask for the generator.

        Args:
            x: ``(H, W, C)`` image.
            mask: Optional mask with the image's layout, expected to hold values
                in ``[0, 255]``; it is binarized to ``{0, 1}``. When ``None``
                an all-ones mask shaped like the image is used.

        Returns:
            ``((image_batch, mask_batch), h, w)``: both batches are
            ``(1, C, H', W')`` tensors zero-padded so ``H'`` and ``W'`` are
            multiples of 32, and ``h, w`` is the size of the resized image
            before padding.
        """
        x = albu.LongestMaxSize()(image=x)['image']
        # normalize_fn appears to take and return an (image, target) pair;
        # only the image is needed here.
        x, _ = self.normalize_fn(x, x)
        h, w, _ = x.shape
        if mask is None:
            mask = np.ones_like(x, dtype=np.float32)
        else:
            if mask.shape[:2] != (h, w):
                # The image was resized above; put the mask on the same grid.
                # Nearest-neighbour keeps it a hard mask, and the reshape
                # restores a trailing channel axis that cv2.resize drops for
                # single-channel input.
                resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
                mask = resized.reshape(h, w, *mask.shape[2:])
            mask = np.round(mask.astype('float32') / 255)

        # Pad H and W up to a multiple of 32 (presumably the generator's total
        # downsampling factor -- confirm). Sizes that are already multiples of
        # 32 get no padding.
        block_size = 32
        pad_height = self._round_up(h, block_size) - h
        pad_width = self._round_up(w, block_size) - w
        pad_params = {'mode': 'constant',
                      'constant_values': 0,
                      'pad_width': ((0, pad_height), (0, pad_width), (0, 0))}
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)
        return (self._array_to_batch(x), self._array_to_batch(mask)), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert a ``(1, C, H, W)`` output, expected in ``[-1, 1]``, to uint8 ``(H, W, C)``."""
        x, = x  # unpack the single-element batch; fails if batch size != 1
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        # Clip before casting: out-of-range floats would otherwise wrap around
        # (or be undefined) in the uint8 conversion.
        return np.clip(x, 0, 255).astype('uint8')

    def __call__(self, img: np.ndarray, mask: Optional[np.ndarray], ignore_mask=True) -> np.ndarray:
        """Restore ``img`` and return the result as a uint8 ``(h, w, C)`` array.

        Args:
            img: ``(H, W, C)`` input image (``main`` passes RGB).
            mask: Optional mask; see ``_preprocess``.
            ignore_mask: When True (the default) only the image is fed to the
                generator; otherwise the mask is passed as a second input.

        Returns:
            The prediction with the padding cropped off. Its size is that of
            the image after ``albu.LongestMaxSize``, which may differ from the
            input size.
        """
        (img, mask), h, w = self._preprocess(img, mask)
        with torch.no_grad():
            inputs = [img.cuda()]
            if not ignore_mask:
                # The mask must be on the same device as the image and model.
                inputs.append(mask.cuda())
            pred = self.model(*inputs)
        return self._postprocess(pred)[:h, :w, :]
def main(img_pattern: str,
         mask_pattern: Optional[str] = None,
         weights_path='best_fpn.h5',
         out_dir='submit/'
         ):
    """Restore every image matching ``img_pattern`` and write the results.

    Args:
        img_pattern: ``glob`` pattern selecting the input images.
        mask_pattern: Optional ``glob`` pattern selecting masks. After sorting
            both lists by path, the i-th mask is paired with the i-th image,
            so both patterns must match the same number of files.
        weights_path: Generator checkpoint passed to ``Predictor``.
        out_dir: Output directory, created if missing. Each prediction is
            written under the basename of its input image, so inputs from
            different directories that share a basename overwrite each other.

    Raises:
        ValueError: If the image and mask patterns match different counts.
        FileNotFoundError: If an image or mask cannot be read by OpenCV.
        OSError: If a prediction cannot be written.
    """
    def read(path):
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'could not read image: {path}')
        return image

    imgs = sorted(glob(img_pattern))
    if mask_pattern is None:
        masks = [None] * len(imgs)
    else:
        masks = sorted(glob(mask_pattern))
        if len(masks) != len(imgs):
            # zip() would silently drop the unmatched tail.
            raise ValueError(f'found {len(imgs)} images but {len(masks)} masks')

    predictor = Predictor(weights_path=weights_path)
    os.makedirs(out_dir, exist_ok=True)

    for f_img, f_mask in tqdm(zip(imgs, masks), total=len(imgs)):
        img = cv2.cvtColor(read(f_img), cv2.COLOR_BGR2RGB)
        # Only read a mask when one was supplied; cv2.imread(None) raises.
        mask = None if f_mask is None else read(f_mask)
        pred = predictor(img, mask)
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        # Derive the output name from this image's own path so the name
        # always matches the image it came from.
        out_path = os.path.join(out_dir, os.path.basename(f_img))
        if not cv2.imwrite(out_path, pred):
            raise OSError(f'could not write prediction: {out_path}')
if __name__ == '__main__':
    # Expose main() as a command-line interface through python-fire.
    Fire(main)