# xugcn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append("models/")
from xumlp import MLP

class GraphCNN(nn.Module):
    archName = "GCN Xu"

    def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device):
        '''
        num_layers: number of layers in the neural network (INCLUDING the input layer)
        num_mlp_layers: number of layers in the MLPs (EXCLUDING the input layer)
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        final_dropout: dropout ratio on the final linear layer
        learn_eps: if True, learn epsilon to distinguish center nodes from neighboring nodes; if False, aggregate center nodes and neighbors together
        graph_pooling_type: how to aggregate all nodes of a graph (sum or average)
        neighbor_pooling_type: how to aggregate neighbors (sum, average, or max)
        device: which device to use
        '''
        super(GraphCNN, self).__init__()
        self.n_epochs = 200
        self.final_dropout = final_dropout
        self.device = device
        self.num_layers = num_layers
        self.graph_pooling_type = graph_pooling_type
        self.neighbor_pooling_type = neighbor_pooling_type
        self.learn_eps = learn_eps
        self.eps = nn.Parameter(torch.zeros(self.num_layers-1))
        self.output_dim = output_dim
        self.hid_channel = hidden_dim

        ### List of MLPs
        self.mlps = torch.nn.ModuleList()
        ### List of batchnorms applied to the output of each MLP (input of the final prediction linear layer)
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(self.num_layers-1):
            if layer == 0:
                self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim))
            else:
                self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Linear functions that map the hidden representation at different layers into a prediction score
        self.linears_prediction = torch.nn.ModuleList()
        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))

    def __preprocess_neighbors_maxpool(self, batch_graph):
        ### create padded_neighbor_list for the concatenated graph
        # compute the maximum number of neighbors within the graphs in the current minibatch
        max_deg = max([graph.max_neighbor for graph in batch_graph])
        padded_neighbor_list = []
        start_idx = [0]
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))
            padded_neighbors = []
            for j in range(len(graph.neighbors)):
                # add offset values to the neighbor indices
                pad = [n + start_idx[i] for n in graph.neighbors[j]]
                # padding; dummy entries are stored as -1
                pad.extend([-1]*(max_deg - len(pad)))
                # add the center node to the max-pooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes together
                if not self.learn_eps:
                    pad.append(j + start_idx[i])
                padded_neighbors.append(pad)
            padded_neighbor_list.extend(padded_neighbors)
        return torch.LongTensor(padded_neighbor_list)
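    # Toy example (illustration only): for a minibatch containing a single triangle
    # graph with nodes {0, 1, 2}, max_deg = 2 and, with learn_eps=True, the returned
    # padded neighbor list is [[1, 2], [0, 2], [0, 1]]; with learn_eps=False the
    # center node index is appended to each row: [[1, 2, 0], [0, 2, 1], [0, 1, 2]].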

    def __preprocess_neighbors_sumavepool(self, batch_graph):
        ### create block-diagonal sparse adjacency matrix
        edge_mat_list = []
        start_idx = [0]
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))
            edge_mat_list.append(graph.edge_mat + start_idx[i])
        Adj_block_idx = torch.cat(edge_mat_list, 1)
        Adj_block_elem = torch.ones(Adj_block_idx.shape[1])

        # add self-loops to the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes together
        if not self.learn_eps:
            num_node = start_idx[-1]
            self_loop_edge = torch.LongTensor([range(num_node), range(num_node)])
            elem = torch.ones(num_node)
            Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
            Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)

        Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1], start_idx[-1]]))
        return Adj_block.to(self.device)
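    # Toy example (illustration only): for a minibatch with graphs of 2 and 3 nodes,
    # the graphs are concatenated into one 5-node graph, the second graph's edge
    # indices are shifted by start_idx[1] = 2, and (when learn_eps is False)
    # self-loops (i, i) are added, giving a 5 x 5 block-diagonal sparse adjacency.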

    def __preprocess_graphpool(self, batch_graph):
        ### create sum or average pooling sparse matrix over all nodes of each graph (num graphs x num nodes)
        start_idx = [0]
        # compute the start index of each graph in the concatenated node list
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))

        idx = []
        elem = []
        for i, graph in enumerate(batch_graph):
            ### average pooling
            if self.graph_pooling_type == "average":
                elem.extend([1./len(graph.g)]*len(graph.g))
            else:
                ### sum pooling
                elem.extend([1]*len(graph.g))
            idx.extend([[i, j] for j in range(start_idx[i], start_idx[i+1], 1)])

        elem = torch.FloatTensor(elem)
        idx = torch.LongTensor(idx).transpose(0, 1)
        graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))
        return graph_pool.to(self.device)
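    # Toy example (illustration only): for the same 2-graph / 5-node minibatch,
    # graph_pool is a 2 x 5 sparse matrix whose row i holds 1 (sum pooling) or
    # 1/|V_i| (average pooling) in the columns of graph i's nodes, so that
    # torch.spmm(graph_pool, h) yields one pooled vector per graph.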

    def maxpool(self, h, padded_neighbor_list):
        ### append a dummy row holding the element-wise minimum of h; the -1 padding indices select it, so it never affects the max-pooling
        dummy = torch.min(h, dim=0)[0]
        h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
        pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim=1)[0]
        return pooled_rep

    def next_layer_eps(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        ### pooling neighboring nodes and center nodes separately by epsilon reweighting
        if self.neighbor_pooling_type == "max":
            ## If max pooling
            pooled = self.maxpool(h, padded_neighbor_list)
        else:
            # If sum or average pooling
            pooled = torch.spmm(Adj_block, h)
            if self.neighbor_pooling_type == "average":
                # If average pooling
                degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
                pooled = pooled/degree
        # Reweight the center node representation when aggregating it with its neighbors
        pooled = pooled + (1 + self.eps[layer])*h
        pooled_rep = self.mlps[layer](pooled)
        h = self.batch_norms[layer](pooled_rep)
        # non-linearity
        h = F.relu(h)
        return h
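    # Note: next_layer_eps above corresponds to the GIN-style update
    #   h_v^(k) = ReLU(BN(MLP_k((1 + eps_k) * h_v^(k-1) + AGG_{u in N(v)} h_u^(k-1)))),
    # while next_layer below omits the eps term because the center node is already
    # folded into the aggregation (via self-loops for sum/average pooling, or by
    # appending the center index to the padded neighbor list for max pooling).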

    def next_layer(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        ### pooling neighboring nodes and center nodes altogether
        if self.neighbor_pooling_type == "max":
            ## If max pooling
            pooled = self.maxpool(h, padded_neighbor_list)
        else:
            # If sum or average pooling
            pooled = torch.spmm(Adj_block, h)
            if self.neighbor_pooling_type == "average":
                # If average pooling
                degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
                pooled = pooled/degree
        # representation of neighboring and center nodes
        pooled_rep = self.mlps[layer](pooled)
        h = self.batch_norms[layer](pooled_rep)
        # non-linearity
        h = F.relu(h)
        return h

    def forward(self, batch_graph):
        X_concat = batch_graph[0].to(self.device)
        batch_graph = batch_graph[-1]
        #X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)
        graph_pool = self.__preprocess_graphpool(batch_graph)

        if self.neighbor_pooling_type == "max":
            padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
        else:
            Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)

        # list of hidden representations at each layer (including the input)
        hidden_rep = [X_concat]
        h = X_concat
        for layer in range(self.num_layers-1):
            if self.neighbor_pooling_type == "max" and self.learn_eps:
                h = self.next_layer_eps(h, layer, padded_neighbor_list=padded_neighbor_list)
            elif self.neighbor_pooling_type != "max" and self.learn_eps:
                h = self.next_layer_eps(h, layer, Adj_block=Adj_block)
            elif self.neighbor_pooling_type == "max" and not self.learn_eps:
                h = self.next_layer(h, layer, padded_neighbor_list=padded_neighbor_list)
            elif self.neighbor_pooling_type != "max" and not self.learn_eps:
                h = self.next_layer(h, layer, Adj_block=Adj_block)
            hidden_rep.append(h)

        score_over_layer = 0
        # perform pooling over all nodes of each graph at every layer
        for layer, h in enumerate(hidden_rep):
            pooled_h = torch.spmm(graph_pool, h)
            score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training=self.training)

        if self.output_dim == 1:  # binary classification
            x = torch.sigmoid(score_over_layer)
            return torch.flatten(x)
        else:
            return torch.nn.functional.log_softmax(score_over_layer, dim=1)
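

if __name__ == "__main__":
    # Minimal smoke test (illustration only). The real pipeline builds graph objects
    # in its data-loading code; here a types.SimpleNamespace with the attributes this
    # file actually uses (.g, .neighbors, .max_neighbor, .edge_mat) stands in for one
    # triangle graph. Hyperparameters below are arbitrary placeholders.
    import types
    triangle = types.SimpleNamespace(
        g=[0, 1, 2],                                      # only len(graph.g) is used
        neighbors=[[1, 2], [0, 2], [0, 1]],
        max_neighbor=2,
        edge_mat=torch.LongTensor([[0, 0, 1, 1, 2, 2],
                                   [1, 2, 0, 2, 0, 1]]),  # directed edge list (2 x E)
    )
    X_concat = torch.eye(3)                               # one-hot node features, input_dim = 3
    model = GraphCNN(num_layers=3, num_mlp_layers=2, input_dim=3, hidden_dim=8,
                     output_dim=1, final_dropout=0.5, learn_eps=False,
                     graph_pooling_type="sum", neighbor_pooling_type="sum",
                     device=torch.device("cpu"))
    # forward() expects [X_concat, ..., list_of_graphs]; it reads the first and last entries.
    print(model([X_concat, [triangle]]))                  # one sigmoid score per graph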