-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
graph_sage_unsup.py
84 lines (67 loc) · 2.21 KB
/
graph_sage_unsup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os.path as osp
import time
import torch
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import GraphSAGE
# Load the Cora Planetoid dataset from ../data/Cora relative to this file,
# applying NormalizeFeatures as a transform.
dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]  # first graph object of the dataset

# Mini-batch loader over edges; neg_sampling_ratio=1.0 requests negative
# samples alongside the positive edges (see the loader docs for details).
train_loader = LinkNeighborLoader(
    data,
    num_neighbors=[10, 10],
    neg_sampling_ratio=1.0,
    batch_size=256,
    shuffle=True,
)

# Prefer CUDA, then XPU, and fall back to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'xpu' if torch_geometric.is_xpu_available()
    else 'cpu'
)

# Only 'x' and 'edge_index' are named in the transfer; the remaining
# attributes (labels, masks) are presumably left where they are so that the
# scikit-learn evaluation can use them directly -- confirm against Data.to.
data = data.to(device, 'x', 'edge_index')

# Two-layer GraphSAGE encoder with 64 hidden channels.
model = GraphSAGE(
    data.num_node_features,
    hidden_channels=64,
    num_layers=2,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
    """Run one training epoch of the link-prediction objective.

    For every batch from ``train_loader`` the model embeds the sampled
    nodes, scores each edge in ``batch.edge_label_index`` with the dot
    product of its two endpoint embeddings, and minimizes binary
    cross-entropy (with logits) against ``batch.edge_label``.

    Returns:
        float: The loss averaged over all scored edges seen in the epoch,
        or ``0.0`` if the loader produced no batches.
    """
    model.train()

    total_loss = 0.0
    total_examples = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        h = model(batch.x, batch.edge_index)
        h_src = h[batch.edge_label_index[0]]
        h_dst = h[batch.edge_label_index[1]]
        # Dot product of endpoint embeddings serves as the edge logit.
        pred = (h_src * h_dst).sum(dim=-1)

        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label)
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean (default reduction), so weight it by
        # the number of scored edges to accumulate a sum over the epoch.
        num_examples = pred.size(0)
        total_loss += float(loss) * num_examples
        total_examples += num_examples

    # Normalize by the number of scored edges. The sum above is weighted per
    # edge, so dividing by the node count would not yield a mean loss.
    return total_loss / total_examples if total_examples else 0.0
@torch.no_grad()
def test():
    """Score the learned embeddings with a logistic-regression classifier.

    Embeds the full graph with the model in eval mode, fits a
    ``LogisticRegression`` on the embeddings selected by ``data.train_mask``
    and reports its accuracy on the validation and test masks.

    Returns:
        tuple: ``(val_acc, test_acc)`` as returned by ``clf.score``.
    """
    model.eval()

    # Bring the embeddings to the CPU before handing them to scikit-learn.
    embeddings = model(data.x, data.edge_index).cpu()
    labels = data.y

    probe = LogisticRegression()
    probe.fit(embeddings[data.train_mask], labels[data.train_mask])

    def accuracy(mask):
        # Accuracy of the fitted probe on the nodes selected by ``mask``.
        return probe.score(embeddings[mask], labels[mask])

    val_acc = accuracy(data.val_mask)
    test_acc = accuracy(data.test_mask)
    return val_acc, test_acc
# Train for 50 epochs, evaluating after each one. The recorded duration of an
# epoch covers training, evaluation and the progress print.
times = []
for epoch in range(1, 51):
    # perf_counter() is intended for measuring short intervals; unlike
    # time.time() it is not affected by system clock adjustments.
    start = time.perf_counter()
    loss = train()
    val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')
    times.append(time.perf_counter() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")