# -*- coding: utf-8 -*-
# @Time : 2020/8/25
# @Author : LeronQ
# @github : https://github.com/LeronQ
# gat.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """One graph attention head: a linear projection followed by adjacency-masked attention aggregation."""
def __init__(self, in_c, out_c):
super(GraphAttentionLayer, self).__init__()
self.in_c = in_c
self.out_c = out_c
self.F = F.softmax
self.W = nn.Linear(in_c, out_c, bias=False) # y = W * x
self.b = nn.Parameter(torch.Tensor(out_c))
nn.init.normal_(self.W.weight)
nn.init.normal_(self.b)

    def forward(self, inputs, graph):
"""
:param inputs: input features, [B, N, C].
:param graph: graph structure, [N, N].
:return:
output features, [B, N, D].
"""
        h = self.W(inputs)  # [B, N, D]; linear projection, i.e. W * h in step one of the formulation
        # The dot product between the features of node i and node j measures how strongly they are related.
        # Multiplying by `graph` (the 0/1 adjacency matrix) zeroes out every pair that is not connected,
        # so a 0 in the result means there is no edge between the two nodes.
        outputs = torch.bmm(h, h.transpose(1, 2)) * graph.unsqueeze(0)  # [B, N, D] * [B, D, N] -> [B, N, N], x(i)^T * x(j)
        # Zeros mark missing edges, so replace them with a very large negative value:
        # after the softmax these positions become (approximately) 0, removing non-neighbours from the attention.
        outputs = outputs.masked_fill(torch.eq(outputs, 0), -float(1e16))
        attention = self.F(outputs, dim=2)  # [B, N, N]; normalise over dim 2 so each node's connected neighbours sum to 1 (the attention coefficients)
        return torch.bmm(attention, h) + self.b  # [B, N, N] * [B, N, D]; step three: aggregate neighbour features weighted by the attention coefficients


class GATSubNet(nn.Module):  # multi-head attention block

    def __init__(self, in_c, hid_c, out_c, n_heads):
        super(GATSubNet, self).__init__()
        # Build n_heads attention heads in a loop; nn.ModuleList registers them as one parallel module.
        self.attention_module = nn.ModuleList(
            [GraphAttentionLayer(in_c, hid_c) for _ in range(n_heads)])  # in_c: input feature dim, hid_c: hidden feature dim
        # Each head produces a different representation; a final attention layer fuses them.
        self.out_att = GraphAttentionLayer(hid_c * n_heads, out_c)
        self.act = nn.LeakyReLU()

    def forward(self, inputs, graph):
        """
        :param inputs: [B, N, C]
        :param graph: [N, N]
        :return: [B, N, out_c]
        """
        # Run every attention head on the same input and concatenate the results along the last dimension.
        outputs = torch.cat([attn(inputs, graph) for attn in self.attention_module], dim=-1)  # [B, N, hid_c * n_heads]
outputs = self.act(outputs)
outputs = self.out_att(outputs, graph)
return self.act(outputs)


class GATNet(nn.Module):

    def __init__(self, in_c, hid_c, out_c, n_heads):
super(GATNet, self).__init__()
self.subnet = GATSubNet(in_c, hid_c, out_c, n_heads)

    def forward(self, data, device):
        graph = data["graph"][0].to(device)  # [N, N]
        flow = data["flow_x"]  # [B, N, T, C]
        flow = flow.to(device)  # move the traffic-flow data to the target device
B, N = flow.size(0), flow.size(1)
flow = flow.view(B, N, -1) # [B, N, T * C]
"""
上面是将这一段的时间的特征数据摊平做为特征,这种做法实际上忽略了时序上的连续性
这种做法可行,但是比较粗糙,当然也可以这么做:
flow[:, :, 0] ... flow[:, :, T-1] 则就有T个[B, N, C]这样的张量,也就是 [B, N, C]*T
每一个张量都用一个SubNet来表示,则一共有T个SubNet,初始化定义 self.subnet = [GATSubNet(...) for _ in range(T)]
然后用nn.ModuleList将SubNet分别拎出来处理,参考多头注意力的处理,同理
"""
prediction = self.subnet(flow, graph).unsqueeze(2) # [B, N, 1, C],这个1加上就表示预测的是未来一个时刻
return prediction
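

# The docstring in GATNet.forward above describes a per-time-step variant: one GATSubNet per input
# time step instead of flattening the window. The class below is a minimal illustrative sketch of
# that idea and is NOT part of the original model; the name GATNetPerStep, the explicit window
# length T, and the simple mean over per-step outputs are assumptions made for clarity.
class GATNetPerStep(nn.Module):
    def __init__(self, in_c, hid_c, out_c, n_heads, T):
        super(GATNetPerStep, self).__init__()
        # One SubNet per time step, registered through nn.ModuleList (mirrors the multi-head pattern).
        # Note: here in_c is the per-step channel count C, not T * C as in GATNet.
        self.subnets = nn.ModuleList([GATSubNet(in_c, hid_c, out_c, n_heads) for _ in range(T)])

    def forward(self, data, device):
        graph = data["graph"][0].to(device)  # [N, N]
        flow = data["flow_x"].to(device)  # [B, N, T, C]
        # Apply the t-th SubNet to the t-th time slice, then average the per-step predictions.
        outputs = [net(flow[:, :, t], graph) for t, net in enumerate(self.subnets)]  # T tensors of [B, N, out_c]
        prediction = torch.stack(outputs, dim=2).mean(dim=2, keepdim=True)  # [B, N, 1, out_c]
        return prediction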


if __name__ == '__main__':  # quick sanity check that the model runs and produces the expected shapes
    x = torch.randn(32, 278, 6, 2)  # [B, N, T, C]
    graph = torch.randn(32, 278, 278)  # [B, N, N]; GATNet.forward uses graph[0], i.e. one [N, N] adjacency per batch
    data = {"flow_x": x, "graph": graph}
    device = torch.device("cpu")
    net = GATNet(in_c=6 * 2, hid_c=6, out_c=2, n_heads=2)  # in_c = T * C because the time window is flattened
    y = net(data, device)
    print(y.size())  # expected: [32, 278, 1, 2]
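
    # Additionally, an illustrative standalone check of a single GraphAttentionLayer (not in the
    # original script): a random binary adjacency matrix stands in for a real graph, and the node
    # features are the same flow data flattened over the time window.
    layer = GraphAttentionLayer(in_c=6 * 2, out_c=4)
    adj = (torch.rand(278, 278) > 0.5).float()  # toy 0/1 adjacency matrix, [N, N]
    out = layer(x.view(32, 278, -1), adj)  # [B, N, 4]
    print(out.size())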