-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
faust.py
82 lines (64 loc) · 2.79 KB
/
faust.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os.path as osp
import sys

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import FAUST
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SplineConv
from torch_geometric.typing import WITH_TORCH_SPLINE_CONV
# SplineConv needs the optional torch-spline-conv extension; stop early with a
# clear message rather than failing later. sys.exit is used instead of the
# quit() builtin, which is only installed by the `site` module (missing under
# `python -S`) and is intended for the interactive interpreter.
if not WITH_TORCH_SPLINE_CONV:
    sys.exit("This example requires 'torch-spline-conv'")
# Dataset root: `data/FAUST` one directory above the folder holding this script.
_script_dir = osp.dirname(osp.realpath(__file__))
path = osp.join(_script_dir, '..', 'data', 'FAUST')

# Pre-transform: FaceToEdge followed by a constant node feature of 1 (the
# first SplineConv layer expects exactly one input channel).
pre_transform = T.Compose([T.FaceToEdge(), T.Constant(value=1)])

# Per-access transform: Cartesian edge attributes, consumed by Net.forward as
# the SplineConv pseudo-coordinates (`data.edge_attr`).
train_dataset = FAUST(path, train=True, transform=T.Cartesian(),
                      pre_transform=pre_transform)
test_dataset = FAUST(path, train=False, transform=T.Cartesian(),
                     pre_transform=pre_transform)

# One mesh per batch: the shared `target` vector is sized for a single mesh.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)

# Reference mesh; its node count sizes the classifier output and the target.
d = train_dataset[0]
class Net(torch.nn.Module):
    """SplineCNN that scores every node against ``d.num_nodes`` classes.

    Six SplineConv layers (3-D pseudo-coordinates, kernel size 5, sum
    aggregation) each followed by ELU, then a per-node MLP
    (64 -> 256 -> d.num_nodes) with dropout in between. ``forward`` returns
    log-probabilities over classes along dim 1.
    """

    def __init__(self):
        super().__init__()
        # Attribute names conv1..conv6 / lin1 / lin2 are kept so state_dict
        # keys stay stable; construction order is unchanged as well.
        self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5, aggr='add')
        self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5, aggr='add')
        self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')
        self.conv4 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')
        self.conv5 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')
        self.conv6 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')
        self.lin1 = torch.nn.Linear(64, 256)
        # Output width comes from the module-level reference mesh `d`.
        self.lin2 = torch.nn.Linear(256, d.num_nodes)

    def forward(self, data):
        """Return per-node log-probabilities for one graph batch.

        Reads ``data.x`` (1 input channel, per conv1), ``data.edge_index``
        and ``data.edge_attr`` (used as 3-D pseudo-coordinates, per dim=3).
        """
        edge_index = data.edge_index
        pseudo = data.edge_attr
        out = data.x
        conv_stack = (self.conv1, self.conv2, self.conv3,
                      self.conv4, self.conv5, self.conv6)
        for conv in conv_stack:
            out = F.elu(conv(out, edge_index, pseudo))
        out = F.elu(self.lin1(out))
        # Dropout is active only while the module is in training mode.
        out = F.dropout(out, training=self.training)
        return F.log_softmax(self.lin2(out), dim=1)
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = Net().to(device)

# Node i of a mesh is assigned class i. The same vector is compared against
# every mesh in train()/test(), which presumes all meshes have d.num_nodes
# nodes in matching order — confirm for other datasets.
target = torch.arange(d.num_nodes, dtype=torch.long, device=device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
def train(epoch):
    """Run one training epoch over ``train_loader``.

    From epoch 61 onward the learning rate of every parameter group is set
    to 0.001.

    Args:
        epoch: 1-based epoch index; only used for the learning-rate switch.

    Returns:
        Mean NLL loss over the batches of this epoch (0.0 if the loader
        yields no batches). Callers that ignore the return value are
        unaffected.
    """
    model.train()
    # `>=` instead of `==` so the decay still takes effect when training is
    # resumed past epoch 61; re-assigning the same lr each epoch is harmless.
    if epoch >= 61:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001

    total_loss = 0.0
    num_batches = 0
    for data in train_loader:
        optimizer.zero_grad()
        loss = F.nll_loss(model(data.to(device)), target)
        loss.backward()
        optimizer.step()
        # Accumulate on-device (detached) so there is only one host sync,
        # at the end of the epoch, instead of one per batch.
        total_loss += loss.detach()
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return float(total_loss) / num_batches
@torch.no_grad()
def test():
    """Return node-correspondence accuracy on ``test_dataset``.

    The predicted class of each node is the argmax of the model's output
    along dim 1; accuracy is the fraction of nodes whose prediction equals
    its entry in ``target``, over all test meshes.

    Runs under ``torch.no_grad()`` so evaluation does not record autograd
    graphs (which the original built and discarded for every test mesh).
    """
    model.eval()
    correct = 0
    for data in test_loader:
        # argmax returns the first maximal index, same as `.max(1)[1]`.
        pred = model(data.to(device)).argmax(dim=1)
        correct += pred.eq(target).sum().item()
    return correct / (len(test_dataset) * d.num_nodes)
# Total number of training epochs; epochs are numbered from 1.
NUM_EPOCHS = 100

# Alternate one training pass with a full evaluation on the test split and
# report the resulting accuracy.
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    test_acc = test()
    print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}')