-
Notifications
You must be signed in to change notification settings - Fork 2
/
vae.py
109 lines (93 loc) · 3.75 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# based on https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
import torch
from torch import nn
from torchvision import transforms
# Dataset location and image/latent geometry used by the rest of this module.
CELEB_PATH = './data/'
IMAGE_SIZE = 150
LATENT_DIM = 128
# Number of values in one flattened RGB image: 3 * 150 * 150 = 67500.
image_dim = IMAGE_SIZE * IMAGE_SIZE * 3

# Echo the active configuration as soon as the module is imported.
print(
    '\nCELEB_PATH', CELEB_PATH,
    'IMAGE_SIZE', IMAGE_SIZE,
    'LATENT_DIM', LATENT_DIM,
    'image_dim', image_dim,
)

# Used when transforming an image to a tensor: resize so the shorter edge is
# IMAGE_SIZE, centre-crop to IMAGE_SIZE x IMAGE_SIZE, then convert to a tensor.
celeb_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, antialias=True),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Used by the decode method to transform the final output: the same resize and
# centre-crop, without ToTensor because the input is already a tensor.
celeb_transform1 = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, antialias=True),
    transforms.CenterCrop(IMAGE_SIZE),
])
class VAE(nn.Module):
    """Convolutional variational autoencoder for square RGB images.

    The encoder is five stride-2 convolution stages (3 -> 32 -> ... -> 512
    channels) followed by two linear heads that produce the mean and the
    log-variance of a ``LATENT_DIM``-dimensional Gaussian. The decoder
    mirrors the encoder with transposed convolutions and ends in a sigmoid.
    Input geometry is taken from the module-level ``IMAGE_SIZE``.
    """

    def __init__(self):
        """Build the encoder, the latent heads and the decoder."""
        super().__init__()
        hidden_dims = [32, 64, 128, 256, 512]
        self.final_dim = hidden_dims[-1]

        # Build Encoder: every stage halves the spatial size (stride 2).
        encoder_stages = []
        in_channels = 3
        for h_dim in hidden_dims:
            encoder_stages.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim
        self.encoder = nn.Sequential(*encoder_stages)

        # Side length of the encoder's final feature map; it fixes the width
        # of the linear layers on both sides of the latent space.
        self.size = self._probe_encoder_size()
        flat_dim = self.final_dim * self.size * self.size
        self.fc_mu = nn.Linear(flat_dim, LATENT_DIM)
        self.fc_var = nn.Linear(flat_dim, LATENT_DIM)

        # Build Decoder: mirror of the encoder, each stage doubles the size.
        self.decoder_input = nn.Linear(LATENT_DIM, flat_dim)
        decoder_dims = hidden_dims[::-1]
        decoder_stages = [
            nn.Sequential(
                nn.ConvTranspose2d(c_in,
                                   c_out,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            )
            for c_in, c_out in zip(decoder_dims, decoder_dims[1:])
        ]
        self.decoder = nn.Sequential(*decoder_stages)

        # One more doubling, then project to 3 channels and squash with a
        # sigmoid.
        last_dim = decoder_dims[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(last_dim,
                               last_dim,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(last_dim),
            nn.LeakyReLU(),
            nn.Conv2d(last_dim, out_channels=3,
                      kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def _probe_encoder_size(self):
        """Return the height of the encoder output for an IMAGE_SIZE input.

        The probe runs in eval mode and without autograd so that it neither
        updates the BatchNorm running statistics with random noise nor
        builds a throw-away graph. Eval mode also avoids BatchNorm's
        "more than 1 value per channel" error when the final feature map
        is 1x1. The encoder's previous train/eval mode is restored afterwards.
        """
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            with torch.no_grad():
                # Only the output shape is used; the values are irrelevant.
                probe = self.encoder(torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE))
        finally:
            self.encoder.train(was_training)
        return probe.shape[2]

    def encode(self, x):
        """Map a batch of images to the parameters of the latent Gaussian.

        Args:
            x: tensor of shape (N, 3, IMAGE_SIZE, IMAGE_SIZE).

        Returns:
            Tuple ``(mu, log_var)``, each of shape (N, LATENT_DIM).
        """
        features = torch.flatten(self.encoder(x), start_dim=1)
        return self.fc_mu(features), self.fc_var(features)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is recovered from the log-variance as ``exp(0.5 * log_var)``.
        The noise is drawn separately, so the result stays differentiable
        with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z):
        """Map latent vectors to flattened reconstructed images.

        Args:
            z: tensor of shape (N, LATENT_DIM).

        Returns:
            Tensor of shape (N, 3 * IMAGE_SIZE * IMAGE_SIZE).
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.final_dim, self.size, self.size)
        result = self.decoder(result)
        result = self.final_layer(result)
        # The five doublings give a 32 * self.size square (160 for the
        # default IMAGE_SIZE of 150); resize/crop back to IMAGE_SIZE.
        result = celeb_transform1(result)
        result = torch.flatten(result, start_dim=1)
        # NOTE(review): nan_to_num replaces NaN/inf with finite numbers, which
        # can hide a diverged model instead of surfacing it - confirm this is
        # intended.
        result = torch.nan_to_num(result)
        return result

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)`` where ``reconstruction``
            is the flattened output of :meth:`decode`.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var