-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
gcn2_ppi.py
102 lines (79 loc) · 3.3 KB
/
gcn2_ppi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os.path as osp
import time
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch.nn import Linear
import torch_geometric.transforms as T
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCN2Conv
# Dataset root lives next to the examples directory, under ../data/GCN2_PPI.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'GCN2_PPI')

# GCN-normalise the adjacency once at preprocessing time and store it as a
# sparse transposed adjacency (``adj_t``); the model therefore runs its
# GCN2Conv layers with ``normalize=False``.
pre_transform = T.Compose([T.GCNNorm(), T.ToSparseTensor()])

# Build the three PPI splits in the order train -> val -> test.
train_dataset, val_dataset, test_dataset = (
    PPI(path, split=split_name, pre_transform=pre_transform)
    for split_name in ('train', 'val', 'test')
)

# Two graphs per mini-batch; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_split, batch_size=2, shuffle=False)
    for eval_split in (val_dataset, test_dataset)
)
class Net(torch.nn.Module):
    """Deep graph network built from a stack of ``GCN2Conv`` layers.

    The input features are projected to ``hidden_channels`` (``x_0``). The
    projection then passes through ``num_layers`` GCN2Conv layers, each of
    which also receives ``x_0``, and a final linear layer produces one logit
    per output channel.

    Args:
        hidden_channels: Width of the hidden representation.
        num_layers: Number of stacked ``GCN2Conv`` layers.
        alpha: Forwarded to every ``GCN2Conv`` as its ``alpha`` argument.
        theta: Forwarded to every ``GCN2Conv`` as its ``theta`` argument.
        shared_weights: Forwarded to every ``GCN2Conv``.
        dropout: Dropout probability applied to the input, before every
            convolution, and before the output projection.
        in_channels: Input feature size. Defaults to
            ``train_dataset.num_features`` (the module-level dataset), which
            preserves the original behaviour.
        out_channels: Number of output logits. Defaults to
            ``train_dataset.num_classes``.
    """

    def __init__(self, hidden_channels, num_layers, alpha, theta,
                 shared_weights=True, dropout=0.0, *,
                 in_channels=None, out_channels=None):
        super().__init__()
        # Fall back to the script-level dataset only when the caller did not
        # size the model explicitly; this keeps old call sites working while
        # letting Net be reused without that global.
        if in_channels is None:
            in_channels = train_dataset.num_features
        if out_channels is None:
            out_channels = train_dataset.num_classes

        self.lins = torch.nn.ModuleList()
        self.lins.append(Linear(in_channels, hidden_channels))
        self.lins.append(Linear(hidden_channels, out_channels))

        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            # GCN2Conv takes a 1-based layer index, hence ``layer + 1``.
            # normalize=False because this script already applies GCNNorm
            # as a dataset pre_transform.
            self.convs.append(
                GCN2Conv(hidden_channels, alpha, theta, layer + 1,
                         shared_weights, normalize=False))

        self.dropout = dropout

    def forward(self, x, adj_t):
        """Return raw (pre-sigmoid) logits for every node.

        Args:
            x: Node features; the last dimension must equal ``in_channels``.
            adj_t: Adjacency passed unchanged to every ``GCN2Conv``;
                presumably the pre-normalised sparse ``adj_t`` produced by
                the dataset pre_transform (TODO confirm for other callers).
        """
        x = F.dropout(x, self.dropout, training=self.training)
        # x_0 is kept and fed to every convolution (initial residual).
        x = x_0 = self.lins[0](x).relu()

        for conv in self.convs:
            h = F.dropout(x, self.dropout, training=self.training)
            h = conv(h, x_0, adj_t)
            # Residual connection to the previous layer's output.
            x = h + x
            x = x.relu()

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.lins[1](x)
        return x
# Use the current CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Wide (2048) and deep (9 GCN2Conv layers) model with per-layer weights.
model = Net(
    hidden_channels=2048,
    num_layers=9,
    alpha=0.5,
    theta=1.0,
    shared_weights=False,
    dropout=0.2,
).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

# Binary cross-entropy on logits, applied independently to every target entry.
criterion = torch.nn.BCEWithLogitsLoss()
def train():
    """Run one optimisation pass over ``train_loader``.

    Returns:
        The training loss averaged over nodes: each batch loss is weighted
        by that batch's ``num_nodes``.
    """
    model.train()
    loss_sum, node_count = 0, 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        logits = model(batch.x, batch.adj_t)
        loss = criterion(logits, batch.y)
        loss.backward()
        optimizer.step()

        n_nodes = batch.num_nodes
        loss_sum += loss.item() * n_nodes
        node_count += n_nodes

    return loss_sum / node_count
@torch.no_grad()
def test(loader):
    """Compute micro-averaged F1 of thresholded predictions over ``loader``.

    A node/label is predicted positive when its logit is > 0. Targets and
    predictions from all batches are concatenated before scoring.

    Returns:
        The micro F1 score, or ``0`` when no positive label was predicted
        at all.
    """
    model.eval()
    targets, predictions = [], []

    for batch in loader:
        targets.append(batch.y)
        logits = model(batch.x.to(device), batch.adj_t.to(device))
        predictions.append((logits > 0).float().cpu())

    y_true = torch.cat(targets, dim=0).numpy()
    y_pred = torch.cat(predictions, dim=0).numpy()

    if y_pred.sum() > 0:
        return f1_score(y_true, y_pred, average='micro')
    return 0
# Train for 2000 epochs, evaluating on both held-out splits after each one
# and recording the wall-clock duration of every epoch.
times = []
for epoch in range(1, 2001):
    epoch_start = time.time()

    loss = train()
    val_f1 = test(val_loader)
    test_f1 = test(test_loader)
    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '
          f'Test: {test_f1:.4f}')

    times.append(time.time() - epoch_start)

median_epoch_time = torch.tensor(times).median()
print(f"Median time per epoch: {median_epoch_time:.4f}s")