autoencoder.py
import torch.nn as nn
from collections import OrderedDict


class AutoEncoder(nn.Module):
    """Symmetric fully-connected autoencoder whose layer widths mirror
    around the latent layer."""

    def __init__(self, args):
        super(AutoEncoder, self).__init__()
        self.args = args
        self.input_dim = args.input_dim
        self.output_dim = self.input_dim
        # Copy before appending so args.hidden_dims is not mutated in place.
        self.hidden_dims = list(args.hidden_dims) + [args.latent_dim]
        self.dims_list = (self.hidden_dims +
                          self.hidden_dims[:-1][::-1])  # mirrored structure
        self.n_layers = len(self.dims_list)
        self.latent_dim = args.latent_dim
        self.n_clusters = args.n_clusters

        # Validation checks: the layer count is odd, and the middle entry
        # of the mirrored structure is the latent layer.
        assert self.n_layers % 2 > 0
        assert self.dims_list[self.n_layers // 2] == self.latent_dim

        # Encoder network: input_dim -> hidden_dims -> latent_dim.
        layers = OrderedDict()
        for idx, hidden_dim in enumerate(self.hidden_dims):
            if idx == 0:
                layers.update({
                    'linear0': nn.Linear(self.input_dim, hidden_dim),
                    'activation0': nn.ReLU(),
                })
            else:
                layers.update({
                    'linear{}'.format(idx): nn.Linear(
                        self.hidden_dims[idx - 1], hidden_dim),
                    'activation{}'.format(idx): nn.ReLU(),
                    'bn{}'.format(idx): nn.BatchNorm1d(hidden_dim),
                })
        self.encoder = nn.Sequential(layers)

        # Decoder network: latent_dim -> reversed hidden_dims -> output_dim.
        layers = OrderedDict()
        tmp_hidden_dims = self.hidden_dims[::-1]
        for idx, hidden_dim in enumerate(tmp_hidden_dims):
            if idx == len(tmp_hidden_dims) - 1:
                # Final reconstruction layer: no activation or batch norm.
                layers.update({
                    'linear{}'.format(idx): nn.Linear(
                        hidden_dim, self.output_dim),
                })
            else:
                layers.update({
                    'linear{}'.format(idx): nn.Linear(
                        hidden_dim, tmp_hidden_dims[idx + 1]),
                    'activation{}'.format(idx): nn.ReLU(),
                    'bn{}'.format(idx): nn.BatchNorm1d(
                        tmp_hidden_dims[idx + 1]),
                })
        self.decoder = nn.Sequential(layers)

    def __repr__(self):
        repr_str = '[Structure]: {}-'.format(self.input_dim)
        for dim in self.dims_list:
            repr_str += '{}-'.format(dim)
        repr_str += str(self.output_dim) + '\n'
        repr_str += '[n_layers]: {}'.format(self.n_layers) + '\n'
        repr_str += '[n_clusters]: {}'.format(self.n_clusters) + '\n'
        repr_str += '[input_dims]: {}'.format(self.input_dim)
        return repr_str

    def __str__(self):
        return self.__repr__()

    def forward(self, X, latent=False):
        # Encode the input; optionally return only the latent code.
        output = self.encoder(X)
        if latent:
            return output
        return self.decoder(output)
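

# Usage sketch (illustrative, not part of the original file): constructs the
# model from a types.SimpleNamespace standing in for parsed argparse
# arguments. The dimensions below are assumed example values, not ones the
# repo mandates.
if __name__ == '__main__':
    import torch
    from types import SimpleNamespace

    args = SimpleNamespace(
        input_dim=784,                 # e.g. flattened 28x28 images
        hidden_dims=[500, 500, 2000],  # encoder widths before the latent layer
        latent_dim=10,
        n_clusters=10,
    )
    model = AutoEncoder(args)
    print(model)  # prints the [Structure] summary from __repr__

    X = torch.randn(32, args.input_dim)  # dummy mini-batch
    X_hat = model(X)                     # reconstruction, shape (32, 784)
    Z = model(X, latent=True)            # latent codes, shape (32, 10)
    print(X_hat.shape, Z.shape)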