gat.py
import torch
from torch import nn


class GraphAttentionLayer(nn.Module):
    """Single multi-head graph attention layer (GAT-style) for batched graphs."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        n_heads: int,
        is_concat: bool = True,
        dropout: float = 0.1,
        leaky_relu_negative_slope: float = 0.2,
    ):
        super().__init__()
        self.is_concat = is_concat
        self.n_heads = n_heads
        if is_concat:
            # Concatenated heads must split out_features evenly.
            assert out_features % n_heads == 0
            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features
        # Projection producing the per-head node embeddings that get aggregated.
        self.proj = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        # Separate, half-width projection used only to compute attention scores.
        self.proj_attn = nn.Linear(in_features, self.n_hidden * n_heads // 2, bias=False)
        # Scores each concatenated source/target pair (n_hidden features) per head.
        self.attn = nn.Linear(self.n_hidden, 1, bias=False)
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        # Normalise attention over the neighbour dimension (dim 2 of (batch, i, j, head)).
        self.softmax = nn.Softmax(dim=2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
        # h: (batch_size, n_nodes, in_features)
        # adj_mat: broadcastable to (n_nodes, n_nodes, n_heads); entries >= 0.5 are edges.
        batch_size, n_nodes = h.shape[0:2]
        # Per-head node embeddings: (batch, n_nodes, n_heads, n_hidden)
        g = self.proj(h).view(batch_size, n_nodes, self.n_heads, self.n_hidden)
        # Half-width embeddings used only for the attention scores.
        ga = self.proj_attn(h).view(batch_size, n_nodes, self.n_heads, self.n_hidden // 2)
        # Build all (i, j) pairs: repeat supplies the "j" copies, repeat_interleave the "i" copies.
        g_repeat = ga.repeat(1, n_nodes, 1, 1)
        g_repeat_interleave = ga.repeat_interleave(n_nodes, dim=1)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
        g_concat = g_concat.view(
            batch_size, n_nodes, n_nodes, self.n_heads, self.n_hidden
        )
        # Unnormalised attention scores e_ij: (batch, n_nodes, n_nodes, n_heads)
        e = self.activation(self.attn(g_concat))
        e = e.squeeze(-1)
        # The adjacency mask may use size-1 dimensions that broadcast over nodes or heads.
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
        # Mask out non-edges before the softmax so they receive zero attention.
        e = e.masked_fill(adj_mat < 0.5, float("-inf"))
        a = self.softmax(e)
        a = self.dropout(a)
        # Weighted sum of neighbour embeddings: (batch, n_nodes, n_heads, n_hidden)
        attn_res = torch.einsum("bijh,bjhf->bihf", a, g)
        if self.is_concat:
            # Concatenate the heads: (batch, n_nodes, n_heads * n_hidden)
            return attn_res.reshape(batch_size, n_nodes, self.n_heads * self.n_hidden)
        else:
            # Average the heads: (batch, n_nodes, n_hidden)
            return attn_res.mean(dim=2)
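

# --- Usage sketch (not part of the original file) ---
# A minimal, assumed example of driving the layer above: the shapes, the
# fully-connected adjacency mask, and the hyper-parameters are illustrative
# assumptions, not values taken from the repository.
if __name__ == "__main__":
    batch_size, n_nodes, in_features = 2, 4, 16
    layer = GraphAttentionLayer(in_features=in_features, out_features=32, n_heads=4)
    # Random node features: (batch, n_nodes, in_features)
    h = torch.randn(batch_size, n_nodes, in_features)
    # Fully-connected graph with self-loops, shared across heads via a size-1 head dim.
    adj_mat = torch.ones(n_nodes, n_nodes, 1)
    out = layer(h, adj_mat)
    print(out.shape)  # torch.Size([2, 4, 32]) with is_concat=True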