-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
159 lines (113 loc) · 4.33 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Some code was borrowed from https://github.com/hwalsuklee/tensorflow-mnist-VAE/blob/master/vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
    """Gaussian MLP encoder producing the parameters of q(z | x).

    Two ReLU hidden layers, each followed by dropout, feed a linear head of
    width ``2 * z_dim``. The first half of the head output is used as the
    mean; the second half is mapped to a strictly positive standard deviation.

    Args:
        img_dim: Size of the last dimension of the input.
        hidden_dim: Width of both hidden layers.
        z_dim: Dimensionality of the latent variable.
        dropout: Drop probability applied after each hidden layer.
    """

    def __init__(self, img_dim, hidden_dim, z_dim, dropout=0.5):
        super().__init__()
        self.img_dim = img_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim

        self.drop = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(self.img_dim, self.hidden_dim)
        self.linear2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        # The head emits mean and (pre-activation) stddev side by side,
        # hence an output width of 2 * z_dim.
        self.out_layer = nn.Linear(self.hidden_dim, self.z_dim * 2)

        # He initialization for the weights, zeros for the biases.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Encode ``x`` into the mean and standard deviation of q(z | x).

        Args:
            x: Tensor whose last dimension has size ``img_dim``; any number
                of leading dimensions (including none) is accepted.

        Returns:
            Tuple ``(mu, sigma)``; both have the shape of ``x`` with the last
            dimension replaced by ``z_dim``, and ``sigma`` is strictly
            positive.
        """
        x = self.drop(F.relu(self.linear1(x)))
        x = self.drop(F.relu(self.linear2(x)))
        params = self.out_layer(x)

        # Split along the feature (last) axis rather than axis 1 so that
        # unbatched and higher-rank inputs are handled like 2-D batches.
        mu = params[..., :self.z_dim]
        # The standard deviation must be positive: parametrize it with a
        # softplus and add a small epsilon for numerical stability.
        sigma = 1e-6 + F.softplus(params[..., self.z_dim:])
        return mu, sigma
class Decoder(nn.Module):
    """Bernoulli MLP decoder mapping a latent code back to input space.

    Two ReLU hidden layers, each followed by dropout, lead into a linear
    layer of width ``img_dim`` whose output is squashed by a sigmoid.

    Args:
        z_dim: Size of the last dimension of the latent input.
        hidden_dim: Width of both hidden layers.
        img_dim: Size of the last dimension of the output.
        dropout: Drop probability applied after each hidden layer.
    """

    def __init__(self, z_dim, hidden_dim, img_dim, dropout=0.5):
        super().__init__()
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.img_dim = img_dim

        self.drop = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(self.z_dim, self.hidden_dim)
        self.linear2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.linear3 = nn.Linear(self.hidden_dim, self.img_dim)

        # He initialization: normal weights, zero biases, for every linear layer.
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Decode latent ``x`` into values in the open interval (0, 1)."""
        hidden = x
        for layer in (self.linear1, self.linear2):
            hidden = self.drop(F.relu(layer(hidden)))
        return torch.sigmoid(self.linear3(hidden))
# Sampling is reparametrized (z = mu + sigma * eps, with eps ~ N(0, I)) so
# that gradients can be backpropagated through the sampling step.
class VAE(nn.Module):
    """Variational autoencoder composed of an ``Encoder`` and a ``Decoder``.

    Note the argument order: ``z_dim`` comes before ``hidden_dim`` here,
    unlike in the ``Encoder`` and ``Decoder`` constructors.

    Args:
        img_dim: Size of the last dimension of the input and reconstruction.
        z_dim: Dimensionality of the latent variable.
        hidden_dim: Width of the hidden layers in both sub-networks.
        dropout: Drop probability forwarded to both sub-networks. Defaults
            to 0.5, the same default ``Encoder`` and ``Decoder`` use.
    """

    def __init__(self, img_dim, z_dim, hidden_dim, dropout=0.5):
        super().__init__()
        self.img_dim = img_dim
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout

        self.encoder = Encoder(
            self.img_dim,
            self.hidden_dim,
            self.z_dim,
            self.dropout,
        )
        self.decoder = Decoder(
            self.z_dim,
            self.hidden_dim,
            self.img_dim,
            self.dropout,
        )

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: Input passed unchanged to the encoder.

        Returns:
            Tuple ``(x_recon, mu, sigma)``: the decoder output for the sampled
            latent code, and the mean and standard deviation returned by the
            encoder.
        """
        mu, sigma = self.encoder(x)
        # Reparametrized sampling: the randomness lives in ``noise`` so that
        # gradients can flow back through ``mu`` and ``sigma``.
        noise = torch.randn_like(mu)
        z = mu + sigma * noise
        x_recon = self.decoder(z)
        return x_recon, mu, sigma
def vae_loss(x_target, x, mu, sigma):
    """Return the negative ELBO averaged over the batch.

    The loss is the Bernoulli reconstruction error plus the closed-form KL
    divergence between N(mu, sigma^2) and the standard normal prior. Both
    terms are summed over all elements and then divided by the batch size,
    so they keep a 1:1 weighting and the loss scale does not depend on the
    batch size.

    Args:
        x_target: Reconstruction with values in [0, 1], same shape as ``x``.
        x: Target with values in [0, 1]; its first dimension is the batch.
        mu: Mean of the approximate posterior.
        sigma: Standard deviation of the approximate posterior.

    Returns:
        Scalar tensor holding the per-sample average loss.
    """
    batch_size = x.size(0)

    # Reconstruction term, summed over every element of the batch.
    generative_loss = F.binary_cross_entropy(x_target, x, reduction='sum')

    # KL(N(mu, sigma^2) || N(0, I)) summed over every element; the epsilon
    # keeps the logarithm finite when sigma is close to zero.
    variance = torch.pow(sigma, 2)
    kld_loss = 0.5 * torch.sum(
        torch.pow(mu, 2) + variance - torch.log(1e-8 + variance) - 1
    )

    # Normalize BOTH terms by the batch size; normalizing only the KL term
    # would shrink its weight relative to the reconstruction term by a
    # factor of batch_size.
    return (generative_loss + kld_loss) / batch_size
# Example usage (VAE takes img_dim, z_dim, hidden_dim, dropout):
# x = torch.rand(10, 30)
# model = VAE(30, 40, 40, 0.4)
# model.eval()
# x_recon, mu, sigma = model(x)
# print(vae_loss(x_recon, x, mu, sigma))