diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ade79105676..20fcf89f1fe6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Accelerated sparse tensor conversion routines ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042), [#7043](https://github.com/pyg-team/pytorch_geometric/pull/7043))
 - Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))
 - Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))
+- Fixed training issues of the GraphGPS example ([#7377](https://github.com/pyg-team/pytorch_geometric/pull/7377))
 
 ### Removed
 
diff --git a/examples/graph_gps.py b/examples/graph_gps.py
index 161148e5026b..1ad7ad7deedb 100644
--- a/examples/graph_gps.py
+++ b/examples/graph_gps.py
@@ -1,7 +1,15 @@
 import os.path as osp
 
 import torch
-from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
+from torch.nn import (
+    BatchNorm1d,
+    Embedding,
+    Linear,
+    ModuleList,
+    ReLU,
+    Sequential,
+)
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 import torch_geometric.transforms as T
 from torch_geometric.datasets import ZINC
@@ -20,11 +28,12 @@
 
 
 class GPS(torch.nn.Module):
-    def __init__(self, channels: int, num_layers: int):
+    def __init__(self, channels: int, pe_dim: int, num_layers: int):
         super().__init__()
 
-        self.node_emb = Embedding(21, channels)
-        self.pe_lin = Linear(20, channels)
+        self.node_emb = Embedding(28, channels - pe_dim)
+        self.pe_lin = Linear(20, pe_dim)
+        self.pe_norm = BatchNorm1d(20)
         self.edge_emb = Embedding(4, channels)
 
         self.convs = ModuleList()
@@ -37,24 +46,33 @@ def __init__(self, channels: int, num_layers: int):
             conv = GPSConv(channels, GINEConv(nn), heads=4, attn_dropout=0.5)
             self.convs.append(conv)
 
-        self.lin = Linear(channels, 1)
+        self.mlp = Sequential(
+            Linear(channels, channels // 2),
+            ReLU(),
+            Linear(channels // 2, channels // 4),
+            ReLU(),
+            Linear(channels // 4, 1),
+        )
 
     def forward(self, x, pe, edge_index, edge_attr, batch):
-        x = self.node_emb(x.squeeze(-1)) + self.pe_lin(pe)
+        x_pe = self.pe_norm(pe)
+        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)
         edge_attr = self.edge_emb(edge_attr)
 
         for conv in self.convs:
             x = conv(x, edge_index, batch, edge_attr=edge_attr)
         x = global_add_pool(x, batch)
-        return self.lin(x)
+        return self.mlp(x)
 
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model = GPS(channels=64, num_layers=10).to(device)
+model = GPS(channels=64, pe_dim=8, num_layers=10).to(device)
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
+scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
+                              min_lr=0.00001)
 
 
-def train(epoch):
+def train():
     model.train()
 
     total_loss = 0
@@ -87,5 +105,6 @@ def test(loader):
-    loss = train(epoch)
+    loss = train()
     val_mae = test(val_loader)
     test_mae = test(test_loader)
+    scheduler.step(val_mae)
     print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
           f'Test: {test_mae:.4f}')