# This code achieves a performance of around 96.60%. However, it is not
# directly comparable to the results reported by the TGN paper since a
# slightly different evaluation setup is used here.
# In particular, predictions in the same batch are made in parallel, i.e.,
# predictions for interactions later in the batch have no access to any
# information whatsoever about previous interactions in the same batch.
# In contrast, when sampling node neighborhoods for interactions later in the
# batch, the reference TGN implementation has access to previous interactions
# in the batch.
# While both approaches are correct, together with the authors of the paper we
# decided to present this version here as it is more realistic and a better
# test bed for future methods.
import os.path as osp
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE')
dataset = JODIEDataset(path, name='wikipedia')
data = dataset[0]
# For small datasets, we can put the whole dataset on GPU and thus avoid
# expensive memory transfer costs for mini-batches:
data = data.to(device)
train_data, val_data, test_data = data.train_val_test_split(
    val_ratio=0.15, test_ratio=0.15)
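
# Each `TemporalDataLoader` yields mini-batches of events in chronological
# order. With `neg_sampling_ratio=1.0`, one random negative destination is
# sampled per positive event and exposed as `batch.neg_dst` below.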
train_loader = TemporalDataLoader(
    train_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
val_loader = TemporalDataLoader(
    val_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
test_loader = TemporalDataLoader(
    test_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)
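

# Temporal graph attention operator: node memory states are contextualized
# with messages from recent neighbors, using an encoding of the relative
# event time as part of the edge features.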
class GraphAttentionEmbedding(torch.nn.Module):
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
                                    dropout=0.1, edge_dim=edge_dim)

    def forward(self, x, last_update, edge_index, t, msg):
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)
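

# Link predictor: a two-layer MLP that scores a (source, destination)
# embedding pair and returns a raw logit.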
class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = h.relu()
        return self.lin_final(h)
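

# Assemble the three components of the model: a TGN memory with identity
# messages and last-message aggregation, the attention-based embedding
# module, and the link predictor. The memory's time encoder is shared with
# the embedding module.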
memory_dim = time_dim = embedding_dim = 100
memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)
link_pred = LinkPredictor(in_channels=embedding_dim).to(device)
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters())
    | set(link_pred.parameters()), lr=0.0001)
criterion = torch.nn.BCEWithLogitsLoss()
# Helper vector to map global node indices to local ones, i.e., to row
# positions in the `n_id` tensor returned by the neighbor loader.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
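

# One training epoch over the events in chronological order. For each batch:
# sample the most recent neighbors of all involved nodes, read their current
# memory, compute embeddings with the attention module, score positive and
# sampled negative destinations, and only then update the memory and the
# neighbor store with the ground-truth events of this batch.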
def train():
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
                data.msg[e_id].to(device))

        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        loss.backward()
        optimizer.step()
        memory.detach()
        total_loss += float(loss) * batch.num_events

    return total_loss / train_data.num_events
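

# Streaming evaluation: the memory and the neighbor store keep being updated
# with the ground-truth events while iterating over the evaluation set, so
# each prediction is made with all information available up to that batch.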
@torch.no_grad()
def test(loader):
    memory.eval()
    gnn.eval()
    link_pred.eval()

    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.

    aps, aucs = [], []
    for batch in loader:
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
                data.msg[e_id].to(device))

        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        y_pred = torch.cat([pos_out, neg_out], dim=0).sigmoid().cpu()
        y_true = torch.cat(
            [torch.ones(pos_out.size(0)),
             torch.zeros(neg_out.size(0))], dim=0)

        aps.append(average_precision_score(y_true, y_pred))
        aucs.append(roc_auc_score(y_true, y_pred))

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

    return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())
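

# Train for 50 epochs. `train()` resets the memory and the neighbor store at
# the start of each epoch; evaluation then continues from the state reached
# at the end of training, replaying validation events before the test events
# so that the memory is warm when test predictions are made.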
for epoch in range(1, 51):
    loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
    val_ap, val_auc = test(val_loader)
    test_ap, test_auc = test(test_loader)
    print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')
    print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')