gnn.py
import torch
import torch.nn.functional as F
import torch.nn as nn
from utils import *
from layers import GAT_gate

# Number of atom features per node; ligand and target features are
# concatenated, so the embedding layer takes 2*N_atom_features inputs.
N_atom_features = 28

class gnn(torch.nn.Module):
    def __init__(self, args):
        super(gnn, self).__init__()
        n_graph_layer = args.n_graph_layer
        d_graph_layer = args.d_graph_layer
        n_FC_layer = args.n_FC_layer
        d_FC_layer = args.d_FC_layer
        self.dropout_rate = args.dropout_rate

        # Stack of gated graph attention layers, all with the same width.
        self.layers1 = [d_graph_layer for i in range(n_graph_layer + 1)]
        self.gconv1 = nn.ModuleList([GAT_gate(self.layers1[i], self.layers1[i + 1])
                                     for i in range(len(self.layers1) - 1)])

        # Fully connected head: d_graph_layer -> d_FC_layer -> ... -> 1.
        self.FC = nn.ModuleList([nn.Linear(self.layers1[-1], d_FC_layer) if i == 0 else
                                 nn.Linear(d_FC_layer, 1) if i == n_FC_layer - 1 else
                                 nn.Linear(d_FC_layer, d_FC_layer) for i in range(n_FC_layer)])

        # Learnable center and width of the Gaussian filter applied to
        # interatomic distances in embede_graph.
        self.mu = nn.Parameter(torch.Tensor([args.initial_mu]).float())
        self.dev = nn.Parameter(torch.Tensor([args.initial_dev]).float())
        self.embede = nn.Linear(2 * N_atom_features, d_graph_layer, bias=False)

    def embede_graph(self, data):
        c_hs, c_adjs1, c_adjs2, c_valid = data
        c_hs = self.embede(c_hs)

        # Combine the covalent adjacency c_adjs1 with a learned Gaussian gate
        # on the interatomic distances in c_adjs2 (see the note after this method).
        c_adjs2 = torch.exp(-torch.pow(c_adjs2 - self.mu.expand_as(c_adjs2), 2) / self.dev) + c_adjs1

        # Left over from a concrete-dropout variant; unused here.
        regularization = torch.empty(len(self.gconv1), device=c_hs.device)

        for k in range(len(self.gconv1)):
            # The node update is the difference between the embedding with and
            # without the distance-based edges.
            c_hs1 = self.gconv1[k](c_hs, c_adjs1)
            c_hs2 = self.gconv1[k](c_hs, c_adjs2)
            c_hs = c_hs2 - c_hs1
            c_hs = F.dropout(c_hs, p=self.dropout_rate, training=self.training)

        # Mask out padding atoms, then sum node embeddings into one graph vector.
        c_hs = c_hs * c_valid.unsqueeze(-1).repeat(1, 1, c_hs.size(-1))
        c_hs = c_hs.sum(1)
        return c_hs
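
    # A note on the distance gate above: for an interatomic distance d, the
    # edge weight exp(-(d - mu)^2 / dev) is a Gaussian bump centered at the
    # learned mu with width set by the learned dev, so the network can tune
    # which interaction distances contribute. Adding c_adjs1 keeps covalent
    # neighbors connected regardless of distance. (This is a reading of the
    # code above, not a claim about the authors' intent.)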

    def fully_connected(self, c_hs):
        # Left over from a concrete-dropout variant; unused here.
        regularization = torch.empty(len(self.FC) - 1, device=c_hs.device)

        for k in range(len(self.FC)):
            if k < len(self.FC) - 1:
                c_hs = self.FC[k](c_hs)
                c_hs = F.dropout(c_hs, p=self.dropout_rate, training=self.training)
                c_hs = F.relu(c_hs)
            else:
                c_hs = self.FC[k](c_hs)

        # Squash the final scalar to (0, 1).
        c_hs = torch.sigmoid(c_hs)
        return c_hs

    def train_model(self, data):
        # Embed the graph into a single vector.
        c_hs = self.embede_graph(data)

        # Fully connected head.
        c_hs = self.fully_connected(c_hs)
        c_hs = c_hs.view(-1)

        # Note: without concrete dropout, the regularization terms are zero.
        return c_hs

    def test_model(self, data1):
        c_hs = self.embede_graph(data1)
        c_hs = self.fully_connected(c_hs)
        c_hs = c_hs.view(-1)
        return c_hs
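

# A minimal smoke-test sketch, not part of the original file: it assumes the
# repository's layers.GAT_gate is importable and uses illustrative
# hyperparameter values only. Shapes follow the code above: node features
# [B, N, 2*N_atom_features], both adjacency matrices [B, N, N], and the
# validity mask [B, N].
if __name__ == "__main__":
    from argparse import Namespace

    # Hypothetical settings; the real values come from the training script's
    # command-line arguments.
    args = Namespace(n_graph_layer=4, d_graph_layer=140, n_FC_layer=4,
                     d_FC_layer=128, dropout_rate=0.3,
                     initial_mu=4.0, initial_dev=1.0)
    model = gnn(args)

    B, N = 2, 50
    h = torch.rand(B, N, 2 * N_atom_features)   # concatenated atom features
    adj1 = (torch.rand(B, N, N) > 0.9).float()  # covalent adjacency (0/1)
    adj2 = torch.rand(B, N, N) * 10.0           # interatomic distances
    valid = torch.ones(B, N)                    # every atom is real (no padding)

    pred = model.train_model((h, adj1, adj2, valid))
    print(pred.shape)  # torch.Size([2]): one score per complex in the batch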