-
Notifications
You must be signed in to change notification settings - Fork 2
/
MAD_VAE.py
108 lines (93 loc) · 3.71 KB
/
MAD_VAE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Normal
# Class for convolution block
class ConvBlock(nn.Module):
    """Conv2d (without bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: convolution kernel size. NOTE(review): the default of 0
            looks like a placeholder; callers should pass a real size.
        stride: convolution stride.
        padding: zero padding applied by the convolution on each side.
    """

    def __init__(self, in_dim, out_dim, kernel_size=0, stride=1, padding=0):
        super().__init__()
        # The convolution carries no bias because the BatchNorm that follows
        # re-centres its output and has its own learnable shift.
        self.conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, normalisation and activation in sequence."""
        for layer in (self.conv, self.norm, self.relu):
            x = layer(x)
        return x
# Class for de-convolution block
class DeConvBlock(nn.Module):
    """ConvTranspose2d (without bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: transposed-convolution kernel size. NOTE(review): the
            default of 0 looks like a placeholder; callers should pass one.
        stride: transposed-convolution stride.
        padding: implicit padding removed from each side of the output.
        out_padding: extra size added to one side of the output
            (forwarded as ``output_padding``).
    """

    def __init__(self, in_dim, out_dim, kernel_size=0, stride=1, padding=0, out_padding=0):
        super().__init__()
        # No bias: the following BatchNorm supplies the learnable shift.
        self.conv = nn.ConvTranspose2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=out_padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upsample with the transposed conv, then normalise and activate."""
        return self.relu(self.norm(self.conv(x)))
# Class for residual conv block
class ResidualConvBlock(nn.Module):
    """Residual block: ``x + F(x)``.

    ``F`` is two stages of ReplicationPad2d -> Conv2d -> InstanceNorm2d ->
    in-place ReLU, all with ``dim`` channels. The sum only works when ``F``
    preserves the shape of ``x`` (e.g. stride=1 and
    padding = (kernel_size - 1) // 2 for odd kernel sizes).

    Args:
        dim: number of channels (input and output).
        kernel_size: kernel size of both convolutions.
        stride: stride of both convolutions.
        padding: replication padding applied before each convolution.
    """

    def __init__(self, dim, kernel_size=0, stride=1, padding=0):
        super().__init__()
        # The layers are registered both as named attributes and inside
        # ``self.module``. Keep this order: it fixes the state_dict keys,
        # the parameter order and the order in which weights are initialised.
        # First stage.
        self.pad1 = nn.ReplicationPad2d(padding)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=stride)
        self.norm1 = nn.InstanceNorm2d(dim)
        self.relu1 = nn.ReLU(inplace=True)
        # Second stage.
        self.pad2 = nn.ReplicationPad2d(padding)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=stride)
        self.norm2 = nn.InstanceNorm2d(dim)
        self.relu2 = nn.ReLU(inplace=True)
        self.module = nn.Sequential(
            self.pad1,
            self.conv1,
            self.norm1,
            self.relu1,
            self.pad2,
            self.conv2,
            self.norm2,
            self.relu2,
        )

    def forward(self, x):
        """Return the input plus the output of the two-stage branch."""
        branch = self.module(x)
        return x + branch
# Main class for MAD-VAE
class MADVAE(nn.Module):
    """Convolutional VAE that encodes images to a diagonal Normal latent and
    decodes latents back to images with pixels in (0, 1).

    Geometry is hard-wired: ``img_decode`` reshapes its hidden vector to
    ``(256, 4, 4)``, so ``args.h_dim`` must be ``256 * 4 * 4 = 4096``. With
    the kernel/stride/padding settings below, a 28x28 input is reduced to a
    4x4 map by the encoder, and the decoder always produces 28x28 output.
    NOTE(review): other image sizes are presumably unsupported — confirm.

    Args:
        args: object providing ``image_size``, ``image_channels``, ``h_dim``
            and ``z_dim`` attributes.
    """

    # Floor added to the softplus scale so Normal's scale stays strictly
    # positive even when softplus underflows to 0 for very negative inputs.
    _MIN_STD = 1e-6

    def __init__(self, args):
        super(MADVAE, self).__init__()
        # NOTE(review): 'DAD-VAE' does not match the class name (MAD-VAE);
        # kept verbatim in case it is used for checkpoint/log file names.
        self.model_name = 'DAD-VAE'
        self.image_size = args.image_size
        self.image_channels = args.image_channels
        self.h_dim = args.h_dim
        self.z_dim = args.z_dim
        # module for encoder
        self.c1 = ConvBlock(self.image_channels, 64, 5, 1, 2)
        self.c2 = ConvBlock(64, 64, 4, 2, 3)
        self.c3 = ConvBlock(64, 128, 4, 2, 1)
        self.c4 = ConvBlock(128, 256, 4, 2, 1)
        self.e_module = nn.Sequential(self.c1, self.c2, self.c3, self.c4)
        # Heads producing the posterior location and (pre-activation) scale.
        self.mu = nn.Linear(self.h_dim, self.z_dim)
        self.sigma = nn.Linear(self.h_dim, self.z_dim)
        # module for image decoder
        self.linear = nn.Linear(self.z_dim, self.h_dim)
        self.d1 = DeConvBlock(256, 128, 4, 2, 1)
        self.d2 = DeConvBlock(128, 64, 4, 2, 1)
        self.d3 = DeConvBlock(64, 64, 4, 2, 3)
        self.d4 = nn.ConvTranspose2d(64, self.image_channels, 5, 1, 2, bias=False)
        self.img_module = nn.Sequential(self.d1, self.d2, self.d3, self.d4)

    # Encoder
    def encode(self, x):
        """Encode a batch of images into the approximate posterior q(z | x).

        Args:
            x: image batch; presumably (batch, image_channels, 28, 28) —
                TODO confirm against callers.

        Returns:
            ``Normal`` whose ``loc`` is the ``mu`` head output and whose
            ``scale`` is ``softplus(sigma head) + _MIN_STD`` (strictly positive).
        """
        batch_size = x.size(0)
        # Still stored on the module in case external code reads it.
        self.batch_size = batch_size
        h = self.e_module(x).view(batch_size, -1)
        mean = self.mu(h)
        # The sigma head is an unconstrained Linear layer, so its raw output
        # can be <= 0. Normal rejects that (ValueError under argument
        # validation), and it is not a valid standard deviation. Softplus
        # maps it to a positive scale.
        std = F.softplus(self.sigma(h)) + self._MIN_STD
        return Normal(mean, std)

    # Decoder for image denoising
    def img_decode(self, z):
        """Decode latents ``z`` of shape (batch, z_dim) into images in (0, 1)."""
        batch_size = z.size(0)
        self.batch_size = batch_size
        h = F.relu(self.linear(z))
        # Requires h_dim == 256 * 4 * 4.
        h = h.view(batch_size, 256, 4, 4)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.img_module(h))

    # Forward function
    def forward(self, x):
        """Encode, pick a latent, and decode.

        Returns:
            Tuple ``(reconstruction, posterior mean, posterior stddev, z)``.
            In training mode ``z`` is a reparameterised sample (gradients
            flow through mean and stddev); in eval mode it is the posterior
            mean.
        """
        dist = self.encode(x)
        z = dist.rsample() if self.training else dist.mean
        output = self.img_decode(z)
        return output, dist.mean, dist.stddev, z