-
Notifications
You must be signed in to change notification settings - Fork 35
/
test.py
81 lines (57 loc) · 2.57 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
from PIL import Image
import torch
import torchvision.transforms as T
parser = argparse.ArgumentParser(description='Test inpainting')
parser.add_argument("--image", type=str,
default="examples/inpaint/case1.png", help="path to the image file")
parser.add_argument("--mask", type=str,
default="examples/inpaint/case1_mask.png", help="path to the mask file")
parser.add_argument("--out", type=str,
default="examples/inpaint/case1_out_test.png", help="path for the output file")
parser.add_argument("--checkpoint", type=str,
default="pretrained/states_tf_places2.pth", help="path to the checkpoint file")
def main():
args = parser.parse_args()
generator_state_dict = torch.load(args.checkpoint)['G']
if 'stage1.conv1.conv.weight' in generator_state_dict.keys():
from model.networks import Generator
else:
from model.networks_tf import Generator
use_cuda_if_available = True
device = torch.device('cuda' if torch.cuda.is_available()
and use_cuda_if_available else 'cpu')
# set up network
generator = Generator(cnum_in=5, cnum=48, return_flow=False).to(device)
generator_state_dict = torch.load(args.checkpoint)['G']
generator.load_state_dict(generator_state_dict, strict=True)
# load image and mask
image = Image.open(args.image)
mask = Image.open(args.mask)
# prepare input
image = T.ToTensor()(image)
mask = T.ToTensor()(mask)
_, h, w = image.shape
grid = 8
image = image[:3, :h//grid*grid, :w//grid*grid].unsqueeze(0)
mask = mask[0:1, :h//grid*grid, :w//grid*grid].unsqueeze(0)
print(f"Shape of image: {image.shape}")
image = (image*2 - 1.).to(device) # map image values to [-1, 1] range
mask = (mask > 0.5).to(dtype=torch.float32,
device=device) # 1.: masked 0.: unmasked
image_masked = image * (1.-mask) # mask image
ones_x = torch.ones_like(image_masked)[:, 0:1, :, :]
x = torch.cat([image_masked, ones_x, ones_x*mask],
dim=1) # concatenate channels
with torch.inference_mode():
_, x_stage2 = generator(x, mask)
# complete image
image_inpainted = image * (1.-mask) + x_stage2 * mask
# save inpainted image
img_out = ((image_inpainted[0].permute(1, 2, 0) + 1)*127.5)
img_out = img_out.to(device='cpu', dtype=torch.uint8)
img_out = Image.fromarray(img_out.numpy())
img_out.save(args.out)
print(f"Saved output file at: {args.out}")
if __name__ == '__main__':
main()