Skip to content

Commit

Permalink
Note on QM9 pre-processed version + AUROC metric in bipartite GraphSA…
Browse files Browse the repository at this point in the history
…GE example (#6553)
  • Loading branch information
rusty1s authored Jan 31, 2023
1 parent d66498d commit 0b1ecb3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 21 deletions.
28 changes: 7 additions & 21 deletions examples/hetero/bipartite_sage_unsup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
import torch
import torch.nn.functional as F
import tqdm
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
)
from sklearn.metrics import roc_auc_score
from torch.nn import Embedding, Linear

import torch_geometric.transforms as T
Expand Down Expand Up @@ -236,22 +231,13 @@ def test(loader):
pred = torch.cat(preds, dim=0).numpy()
target = torch.cat(targets, dim=0).numpy()

pred = pred > 0.5
acc = accuracy_score(target, pred)
prec = precision_score(target, pred)
rec = recall_score(target, pred)
f1 = f1_score(target, pred)

return acc, prec, rec, f1
return roc_auc_score(target, pred)


for epoch in range(1, 21):
loss = train()
val_acc, val_prec, val_rec, val_f1 = test(val_loader)
test_acc, test_prec, test_rec, test_f1 = test(test_loader)

print(f'Epoch: {epoch:03d}, Loss: {loss:4f}')
print(f'Val Acc: {val_acc:.4f}, Val Precision {val_prec:.4f}, '
f'Val Recall {val_rec:.4f}, Val F1 {val_f1:.4f}')
print(f'Test Acc: {test_acc:.4f}, Test Precision {test_prec:.4f}, '
f'Test Recall {test_rec:.4f}, Test F1 {test_f1:.4f}')
val_auc = test(val_loader)
test_auc = test(test_loader)

print(f'Epoch: {epoch:02d}, Loss: {loss:4f}, Val: {val_auc:.4f}, '
f'Test: {test_auc:.4f}')
6 changes: 6 additions & 0 deletions torch_geometric/datasets/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ class QM9(InMemoryDataset):
| 18 | :math:`C` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
.. note::
We also provide a pre-processed version of the dataset in case
:class:`rdkit` is not installed. The pre-processed version matches with
the manually processed version as outlined in :meth:`process`.
Args:
root (str): Root directory where the dataset should be saved.
transform (callable, optional): A function/transform that takes in an
Expand Down

0 comments on commit 0b1ecb3

Please sign in to comment.