forked from hi-zhenyu/PVC
-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
72 lines (60 loc) · 2.1 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import math
import torch
import torch.nn as nn
class PVC(nn.Module):
    """Per-view fully connected autoencoders.

    One encoder/decoder pair is built for each entry of ``arch_list``. Each
    entry is a sequence of layer widths ``[d_0, d_1, ..., d_k]``: the encoder
    maps ``d_0 -> ... -> d_k`` and the decoder mirrors it back
    ``d_k -> ... -> d_0``. ReLU is applied between layers; the final layer of
    both the encoder and the decoder has no activation.

    Args:
        arch_list: One layer-width sequence per view. It is not modified.

    Attributes:
        view_size: Number of views (``len(arch_list)``).
        enc_list: ``nn.ModuleList`` holding one ``nn.ModuleList`` of
            ``nn.Linear`` encoder layers per view.
        dec_list: Same, for the decoders.
        dim: Last width of view 0's architecture (its encoder output width).
    """

    def __init__(self, arch_list):
        super().__init__()
        self.view_size = len(arch_list)
        self.enc_list = nn.ModuleList()
        self.dec_list = nn.ModuleList()
        self.relu = nn.ReLU()
        # tanh / sigm are not used by forward(); kept so existing attribute
        # access on the module keeps working.
        self.tanh = nn.Tanh()
        self.sigm = nn.Sigmoid()
        # network
        for arch in arch_list:
            enc, dec = self.single_ae(arch)
            self.enc_list.append(enc)
            self.dec_list.append(dec)
        # Previously this read arch_list[0][0] *after* single_ae had reversed
        # the list in place, i.e. it actually got the last width. Read that
        # value directly instead of relying on the mutation.
        self.dim = arch_list[0][-1]

    def reset_parameters(self):
        """Re-initialize ``self.A`` as identity plus small uniform noise.

        ``A`` is filled from U(-1/sqrt(dim), 1/sqrt(dim)) and then a
        ``dim x dim`` identity is added. This class never creates ``A``
        itself; it must be assigned (e.g. an ``nn.Parameter`` of shape
        ``(dim, dim)``) before this method is called.

        Raises:
            AttributeError: if ``self.A`` has not been set.
        """
        A = getattr(self, "A", None)
        if A is None:
            raise AttributeError(
                "PVC.reset_parameters() requires `self.A` to be set "
                "(e.g. an nn.Parameter of shape (dim, dim)) before it is called"
            )
        stdv = 1.0 / math.sqrt(self.dim)
        with torch.no_grad():
            A.uniform_(-stdv, stdv)
            # Create the identity on A's device/dtype so this also works after
            # the model has been moved to GPU or cast to another dtype.
            A.add_(torch.eye(self.dim, dtype=A.dtype, device=A.device))

    def single_ae(self, arch):
        """Build the encoder and mirrored decoder for one view.

        Args:
            arch: Layer widths ``[d_0, ..., d_k]``. Not modified.

        Returns:
            ``(enc, dec)``: two ``nn.ModuleList`` of ``nn.Linear`` layers;
            ``enc`` maps ``d_0 -> ... -> d_k`` and ``dec`` maps
            ``d_k -> ... -> d_0``. Both are empty when ``arch`` has fewer
            than two entries.
        """
        # Work on copies: reversing the caller's list in place corrupted
        # arch_list and broke views that share the same list object.
        widths = list(arch)
        enc = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(widths, widths[1:])
        )
        rev = widths[::-1]
        dec = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(rev, rev[1:])
        )
        return enc, dec

    def _run_layers(self, layers, x):
        """Apply ``layers`` in order with ReLU between them (none after the last)."""
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            x = layer(x)
            if i < last:
                x = self.relu(x)
        return x

    def forward(self, inputs_list):
        """Encode and reconstruct every view.

        Args:
            inputs_list: One tensor per view, indexed by view. The last
                dimension of view ``v``'s tensor must equal
                ``arch_list[v][0]`` (presumably ``(batch, d_0)`` — confirm
                with callers).

        Returns:
            ``(encoded_list, decoded_list)``: per-view encoder outputs and
            decoder reconstructions, in view order.
        """
        encoded_list = []
        decoded_list = []
        for view in range(self.view_size):
            encoded = self._run_layers(self.enc_list[view], inputs_list[view])
            decoded = self._run_layers(self.dec_list[view], encoded)
            encoded_list.append(encoded)
            decoded_list.append(decoded)
        return encoded_list, decoded_list