-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraph_learner.py
103 lines (86 loc) · 3.43 KB
/
graph_learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch.nn as nn
import torch
import torch.nn.functional as F
# Small constant added to degree terms in `normalize` before taking their
# reciprocal, so that a zero degree does not produce a division by zero.
EOS = 1e-10
class Attentive(nn.Module):
    """Learnable per-feature scaling layer.

    Holds one weight per input feature (initialised to 1) and multiplies
    each feature of the input by its weight.
    """

    def __init__(self, isize):
        """Create the layer.

        Args:
            isize: Size of the last dimension of the tensors passed to
                ``forward``; one weight is created per feature.
        """
        super(Attentive, self).__init__()
        self.w = nn.Parameter(torch.ones(isize))

    def forward(self, x):
        """Scale the last dimension of ``x`` by the learned weights.

        Args:
            x: Tensor whose last dimension has ``isize`` entries.

        Returns:
            Tensor of the same shape as ``x`` with feature ``j`` multiplied
            by ``self.w[j]``.
        """
        if x.layout != torch.strided:
            # NOTE(review): broadcasting support for non-strided (sparse)
            # layouts varies, so keep the original matmul form for them.
            return x @ torch.diag(self.w)
        # Same result as ``x @ torch.diag(self.w)`` without materialising an
        # (isize x isize) matrix or paying for a full matrix product.
        return x * self.w
class ATT_learner(nn.Module):
    """Learn a normalised adjacency matrix from node features.

    The features go through a stack of ``Attentive`` layers, and the
    resulting embeddings are turned into a graph by ``forward``.
    """

    def __init__(self, nlayers, isize, k, knn_metric, i, sparse, mlp_act):
        """Build the learner.

        Args:
            nlayers: Number of ``Attentive`` layers to stack.
            isize: Feature size handed to every ``Attentive`` layer.
            k: ``forward`` keeps ``k + 1`` entries per row of the
                similarity matrix.
            knn_metric: Stored on the instance; not read by the methods
                of this class.
            i: Forwarded to ``apply_non_linearity`` in ``forward``.
            sparse: Forwarded to ``normalize`` in ``forward``.
            mlp_act: Activation between layers: ``"relu"`` or ``"tanh"``;
                any other value applies no activation.
        """
        super().__init__()
        self.k = k
        self.i = i
        self.sparse = sparse
        self.knn_metric = knn_metric
        self.mlp_act = mlp_act
        # The graph non-linearity is fixed here rather than configurable.
        self.non_linearity = 'relu'
        self.layers = nn.ModuleList(
            [Attentive(isize) for _ in range(nlayers)]
        )

    def internal_forward(self, h):
        """Run ``h`` through the layer stack.

        The activation selected by ``mlp_act`` is applied after every
        layer except the last one.
        """
        last_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            h = layer(h)
            if index == last_index:
                # No activation after the final layer.
                continue
            if self.mlp_act == "relu":
                h = F.relu(h)
            elif self.mlp_act == "tanh":
                h = F.tanh(h)
        return h

    def forward(self, features):
        """Return the learned adjacency for ``features``.

        Steps: embed, L2-normalise each row, compute pairwise similarities,
        keep the ``k + 1`` largest entries per row, apply the
        non-linearity, symmetrise, then apply ``'sym'`` normalisation.
        """
        unit_embeddings = F.normalize(
            self.internal_forward(features), dim=1, p=2
        )
        graph = top_k(cal_similarity_graph(unit_embeddings), self.k + 1)
        graph = apply_non_linearity(graph, self.non_linearity, self.i)
        return normalize(symmetrize(graph), 'sym', sparse=self.sparse)
def apply_non_linearity(tensor, non_linearity, i):
    """Apply the named non-linearity to ``tensor``.

    Args:
        tensor: Input tensor.
        non_linearity: ``'elu'``, ``'relu'`` or ``'none'``.
        i: Scale used only by the ``'elu'`` branch, which computes
            ``elu(tensor * i - i) + 1``.

    Returns:
        The transformed tensor; ``'none'`` returns ``tensor`` itself.

    Raises:
        NameError: If ``non_linearity`` is not one of the supported names.
    """
    if non_linearity == 'none':
        return tensor
    if non_linearity == 'relu':
        return F.relu(tensor)
    if non_linearity == 'elu':
        shifted = tensor * i - i
        return F.elu(shifted) + 1
    raise NameError('We dont support the non-linearity yet')
def cal_similarity_graph(node_embeddings):
    """Return the matrix of pairwise inner products between embedding rows.

    Args:
        node_embeddings: 2-D tensor with one embedding per row.

    Returns:
        Square tensor whose ``[a, b]`` entry is the inner product of rows
        ``a`` and ``b`` of ``node_embeddings``.
    """
    return node_embeddings.mm(node_embeddings.t())
def top_k(raw_graph, K):
    """Keep the ``K`` largest entries of every row and zero out the rest.

    Args:
        raw_graph: Tensor of scores; selection is along the last dimension.
        K: Number of entries to keep per row; converted with ``int``.

    Returns:
        ``raw_graph`` multiplied by a 0/1 mask that is 1 at the positions of
        the ``K`` largest entries of each row. Gradients flow only through
        the kept entries because the mask itself is a constant.
    """
    _, indices = raw_graph.topk(k=int(K), dim=-1)
    # Allocate the mask on the input's device instead of hard-coding CUDA, so
    # the function also works on CPU-only hosts and with any GPU index.
    mask = torch.zeros(raw_graph.shape, device=raw_graph.device)
    mask.scatter_(-1, indices, 1.)
    return raw_graph * mask
def symmetrize(adj):
    """Return the average of ``adj`` and its transpose.

    Only intended for non-sparse (dense) inputs.
    """
    transposed = adj.T
    total = adj + transposed
    return total / 2
def normalize(adj, mode, sparse=False):
    """Degree-normalise an adjacency matrix.

    The degree of a node is the sum of its row (``dim=1``). ``EOS`` is added
    to every denominator so that zero-degree rows do not divide by zero.

    Args:
        adj: Dense 2-D tensor, or a sparse COO tensor when ``sparse`` is true.
        mode: ``"sym"`` scales entry ``(r, c)`` by
            ``1 / sqrt(deg[r])`` and ``1 / sqrt(deg[c])``; ``"row"`` scales
            it by ``1 / deg[r]``.
        sparse: Whether ``adj`` is a sparse tensor.

    Returns:
        The normalised matrix: dense for dense input, sparse COO (same
        indices and size as the coalesced input) for sparse input.

    Raises:
        ValueError: If ``mode`` is neither ``"sym"`` nor ``"row"``.
    """
    if mode not in ("sym", "row"):
        # Raise instead of calling exit(): library code should not terminate
        # the interpreter, and exit() is only provided by the site module.
        raise ValueError("wrong norm mode: {!r}".format(mode))

    if not sparse:
        degree = adj.sum(dim=1, keepdim=False)
        if mode == "sym":
            inv_sqrt_degree = 1. / (torch.sqrt(degree) + EOS)
            return inv_sqrt_degree[:, None] * adj * inv_sqrt_degree[None, :]
        inv_degree = 1. / (degree + EOS)
        return inv_degree[:, None] * adj

    adj = adj.coalesce()
    indices = adj.indices()
    row_ids, col_ids = indices[0], indices[1]
    # Densify the row sums: the values() of the sparse sum only cover rows
    # that have entries, so indexing them by row id would be misaligned (or
    # out of range) as soon as one row is empty.
    degree = torch.sparse.sum(adj, dim=1).to_dense()
    if mode == "sym":
        # EOS is included here for consistency with the dense branch.
        inv_sqrt_degree = 1. / (torch.sqrt(degree) + EOS)
        scale = inv_sqrt_degree[row_ids] * inv_sqrt_degree[col_ids]
    else:
        inv_degree = 1. / (degree + EOS)
        scale = inv_degree[row_ids]
    # sparse_coo_tensor keeps the dtype and device of the inputs, unlike the
    # legacy torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, adj.values() * scale, adj.size())