# generate_frey.py
from __future__ import print_function
import argparse
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import scipy.io
import string
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
# Standalone generator:
# load the saved model and generate new faces.
# The VAE class must be defined before torch.load is called, because
# unpickling a whole saved model looks the class up by name.
class VAE(nn.Module):
    """Fully-connected variational autoencoder over 560-value inputs.

    The encoder maps a flattened 560-value input through a 200-unit hidden
    layer to the mean and log-variance of a 20-dimensional latent code.
    The decoder maps a 20-dimensional latent code through a 200-unit hidden
    layer back to 560 sigmoid-activated outputs.

    The layer names and sizes must stay as they are so that a model saved
    with this class can still be loaded.
    """

    def __init__(self):
        super(VAE, self).__init__()

        # Encoder: 560 -> 200 -> (mu, logvar), each of size 20.
        self.fc1 = nn.Linear(560, 200)
        self.fc21 = nn.Linear(200, 20)
        self.fc22 = nn.Linear(200, 20)

        # Decoder: 20 -> 200 -> 560.
        self.fc3 = nn.Linear(20, 200)
        self.fc4 = nn.Linear(200, 560)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Placeholders for the most recent posterior parameters so they can
        # be read off the model later. They are zero-initialised on the CPU
        # (instead of uninitialised CUDA memory) so the class can also be
        # constructed on machines without a GPU; forward() overwrites them.
        self.mu_model = Variable(torch.zeros(20), requires_grad=False)
        self.logvar_model = Variable(torch.zeros(20), requires_grad=False)

    def encode(self, x):
        """Return the ``(mu, logvar)`` pair produced by the encoder for ``x``."""
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = logvar.mul(0.5).exp_()
        # Allocate the noise from ``std`` itself so it has the same type and
        # lives on the same device (CPU or GPU) as the model's activations.
        eps = Variable(std.data.new(std.size()).normal_())
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map latent codes ``z`` to 560 outputs squashed into [0, 1] by a sigmoid."""
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode ``x``, sample a latent code and return its reconstruction.

        ``x`` is flattened to rows of 560 values before encoding. The
        posterior parameters of this pass are kept (detached from the
        autograd graph) in ``mu_model`` and ``logvar_model``.
        """
        # The original body returned decode(z) without ever defining z,
        # which raised NameError on every call.
        mu, logvar = self.encode(x.view(-1, 560))
        self.mu_model = mu.detach()
        self.logvar_model = logvar.detach()
        z = self.reparametrize(mu, logvar)
        return self.decode(z)
# Load the pickled model and render a 10x10 grid of generated faces.
#
# NOTE(security): torch.load unpickles arbitrary Python objects, so only
# load checkpoints that come from a trusted source.
# NOTE(review): recent PyTorch releases may default torch.load to
# weights-only loading, which rejects whole pickled modules -- confirm
# against the installed version.
use_cuda = torch.cuda.is_available()
# Without a GPU, remap any CUDA-saved storages onto the CPU; with a GPU,
# keep the default behaviour (restore onto the device they were saved from).
map_location = None if use_cuda else (lambda storage, loc: storage)
model = torch.load('./out/save.model', map_location=map_location)
print(model)

# Draw 100 latent vectors from a standard normal distribution.
z = Variable(torch.randn(100, 20))
print(z[0])
if use_cuda:
    # Keep the model and the latent codes on the same device.
    model = model.cuda()
    z = z.cuda()
x_gen = model.decode(z)

print('Generating new frey face manifold')
samples = x_gen.data.cpu().numpy()

fig = plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(10, 10)
gs.update(wspace=0.01, hspace=0.01)
for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    # Each sample holds 560 values, shown as a 28x20 grayscale image.
    plt.imshow(sample.reshape(28, 20), cmap='gray')

if not os.path.exists('out/'):
    os.makedirs('out/')
# The colormap is already set per image in imshow; savefig takes no cmap
# argument, and newer matplotlib versions reject unknown keywords.
plt.savefig('out/generated_manifold.png', bbox_inches='tight')
plt.close(fig)