Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update GraphGPS example #7377

Merged
merged 9 commits into from
May 17, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Accelerated sparse tensor conversion routines ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042), [#7043](https://github.com/pyg-team/pytorch_geometric/pull/7043))
- Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))
- Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))
- Fixed training issues of the GraphGPS example ([#7377](https://github.com/pyg-team/pytorch_geometric/pull/7377))

### Removed

Expand Down
37 changes: 28 additions & 9 deletions examples/graph_gps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import os.path as osp

import torch
from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
from torch.nn import (
BatchNorm1d,
Embedding,
Linear,
ModuleList,
ReLU,
Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
Expand All @@ -20,11 +28,12 @@


class GPS(torch.nn.Module):
def __init__(self, channels: int, num_layers: int):
def __init__(self, channels: int, pe_dim: int, num_layers: int):
super().__init__()

self.node_emb = Embedding(21, channels)
self.pe_lin = Linear(20, channels)
self.node_emb = Embedding(28, channels - pe_dim)
self.pe_lin = Linear(20, pe_dim)
self.pe_norm = BatchNorm1d(20)
self.edge_emb = Embedding(4, channels)

self.convs = ModuleList()
Expand All @@ -37,24 +46,33 @@ def __init__(self, channels: int, num_layers: int):
conv = GPSConv(channels, GINEConv(nn), heads=4, attn_dropout=0.5)
self.convs.append(conv)

self.lin = Linear(channels, 1)
self.mlp = Sequential(
Linear(channels, channels // 2),
ReLU(),
Linear(channels // 2, channels // 4),
ReLU(),
Linear(channels // 4, 1),
)

def forward(self, x, pe, edge_index, edge_attr, batch):
x = self.node_emb(x.squeeze(-1)) + self.pe_lin(pe)
x_pe = self.pe_norm(pe)
x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)
edge_attr = self.edge_emb(edge_attr)

for conv in self.convs:
x = conv(x, edge_index, batch, edge_attr=edge_attr)
x = global_add_pool(x, batch)
return self.lin(x)
return self.mlp(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GPS(channels=64, num_layers=10).to(device)
model = GPS(channels=64, pe_dim=8, num_layers=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
min_lr=0.00001)


def train(epoch):
def train():
model.train()

total_loss = 0
Expand Down Expand Up @@ -87,5 +105,6 @@ def test(loader):
loss = train(epoch)
val_mae = test(val_loader)
test_mae = test(test_loader)
scheduler.step(val_mae)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
f'Test: {test_mae:.4f}')