semantic_edit.py
import argparse
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image

from stylegan_layers import G_mapping, G_synthesis

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def main():
    parser = argparse.ArgumentParser(description='Edit a latent code along precomputed semantic boundaries')
    parser.add_argument('--batch_size', default=1, type=int, help='Batch size for the generator')
    parser.add_argument('--resolution', default=1024, type=int)
    parser.add_argument('--weight_file', default="weight_files/pytorch/karras2019stylegan-ffhq-1024x1024.pt", type=str)
    parser.add_argument('--latent_file', default="latent_W/0.npy", type=str)
    args = parser.parse_args()

    # Build the StyleGAN generator (mapping + synthesis) and load the pretrained weights.
    g_all = nn.Sequential(OrderedDict([
        ('g_mapping', G_mapping()),
        # ('truncation', Truncation(avg_latent)),
        ('g_synthesis', G_synthesis(resolution=args.resolution))
    ]))
    g_all.load_state_dict(torch.load(args.weight_file, map_location=device))
    g_all.eval()
    g_all.to(device)
    g_synthesis = g_all[1]

    # Semantic boundaries in W space and the attribute each one controls.
    boundary_name = [
        "stylegan_ffhq_gender_w_boundary.npy",
        "stylegan_ffhq_age_w_boundary.npy",
        "stylegan_ffhq_pose_w_boundary.npy",
        "stylegan_ffhq_eyeglasses_w_boundary.npy",
        "stylegan_ffhq_smile_w_boundary.npy",
    ]
    semantic = ["gender", "age", "pose", "eye_glass", "smile"]

    for i in range(len(boundary_name)):
        latents_0 = np.load(args.latent_file)
        latents_0 = torch.tensor(latents_0).to(device)  # .unsqueeze(0)
        boundary = np.load("boundaries/" + boundary_name[i])
        make_morph(boundary, i, latents_0, g_synthesis, semantic)


def make_morph(boundary, i, latents_0, g_synthesis, semantic):
    # Shift the latent code along the boundary direction with strengths swept from -3 to +3.
    boundary = boundary.reshape(1, 1, -1)
    linspace = np.linspace(-3, 3, 10)
    linspace = linspace.reshape(-1, 1, 1).astype(np.float32)

    boundary = torch.tensor(boundary).to(device)
    linspace = torch.tensor(linspace).to(device)
    latent_code = latents_0 + linspace * boundary  # broadcasts to a batch of 10 edited codes
    latent_code = latent_code.to(torch.float)

    with torch.no_grad():
        synth_img = g_synthesis(latent_code)
        synth_img = (synth_img + 1.0) / 2.0  # map generator output from [-1, 1] to [0, 1]
    save_image(synth_img, "save_image/boundary/{}.png".format(semantic[i]))


if __name__ == "__main__":
    main()
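
# Usage sketch (assumes the pretrained FFHQ weights and the boundary .npy files
# have been downloaded to the default paths used above):
#
#     python semantic_edit.py --latent_file latent_W/0.npy --resolution 1024
#
# For each of the five attributes this writes one image grid to
# save_image/boundary/<attribute>.png, containing the 10 edited versions of the
# input latent code with edit strength swept from -3 to +3.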