Spectral Modularity Pool layer (#4166)
* DMonPool Layer
* DMonPool Example
* DMonPool Tests
* Updated __init__.py
* Update dense_dmon_pool.py
* use MLP
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: fork123aniket <fork123aniket>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent 4231111 · commit 667e1a8 · 4 changed files with 266 additions and 0 deletions.
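Before the diffs: the new layer consumes dense node features and adjacencies and returns the soft cluster assignment matrix, the pooled outputs, and three auxiliary losses. A minimal usage sketch mirroring the tests added in this commit (shapes are illustrative):

import torch
from torch_geometric.nn import DMonPooling

# Dense inputs: node features [B, N, F] and adjacency [B, N, N].
x = torch.randn(2, 20, 16)
adj = torch.ones(2, 20, 20)

pool = DMonPooling([16, 16], k=10)  # MLP channels, number of clusters
s, x_pooled, adj_pooled, spectral, ortho, cluster = pool(x, adj)
print(x_pooled.shape)  # torch.Size([2, 10, 16])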
examples/proteins_dmon_pool.py
@@ -0,0 +1,102 @@
from math import ceil

import torch
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGraphConv, DMonPooling, GCNConv
from torch_geometric.utils import to_dense_adj, to_dense_batch

# 80/10/10 train/val/test split over the shuffled PROTEINS dataset.
dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS').shuffle()
avg_num_nodes = int(dataset.data.x.size(0) / len(dataset))
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DataLoader(test_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
train_loader = DataLoader(train_dataset, batch_size=20)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels=32):
        super().__init__()

        # Each pooling step halves the (average) number of nodes.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        num_nodes = ceil(0.5 * avg_num_nodes)
        self.pool1 = DMonPooling([hidden_channels, hidden_channels], num_nodes)

        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
        num_nodes = ceil(0.5 * num_nodes)
        self.pool2 = DMonPooling([hidden_channels, hidden_channels], num_nodes)

        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index).relu()
        # DMonPooling operates on dense node features and adjacencies.
        x, mask = to_dense_batch(x, batch)
        adj = to_dense_adj(edge_index, batch)

        _, x, adj, sp1, o1, c1 = self.pool1(x, adj, mask)

        x = self.conv2(x, adj).relu()
        _, x, adj, sp2, o2, c2 = self.pool2(x, adj)

        x = self.conv3(x, adj)

        x = x.mean(dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        # Sum the spectral, orthogonality and cluster losses of both pools.
        return F.log_softmax(x, dim=-1), sp1 + sp2 + o1 + o2 + c1 + c2


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)


def train(train_loader):
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out, tot_loss = model(data.x, data.edge_index, data.batch)
        # Classification loss plus the auxiliary pooling losses.
        loss = F.nll_loss(out, data.y.view(-1)) + tot_loss
        loss.backward()
        loss_all += data.y.size(0) * loss.item()
        optimizer.step()
    return loss_all / len(train_dataset)


@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    loss_all = 0

    for data in loader:
        data = data.to(device)
        pred, tot_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(pred, data.y.view(-1)) + tot_loss
        loss_all += data.y.size(0) * loss.item()
        correct += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item()

    return loss_all / len(loader.dataset), correct / len(loader.dataset)


for epoch in range(100):
    train_loss = train(train_loader)
    _, train_acc = test(train_loader)
    val_loss, val_acc = test(val_loader)
    test_loss, test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '
          f'Train Acc: {train_acc:.3f}, Val Loss: {val_loss:.3f}, '
          f'Val Acc: {val_acc:.3f}, Test Loss: {test_loss:.3f}, '
          f'Test Acc: {test_acc:.3f}')
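The example converts each sparse mini-batch to dense form before pooling. As a quick illustration of that conversion step (toy tensors, not taken from the example above):

import torch
from torch_geometric.utils import to_dense_adj, to_dense_batch

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x = torch.randn(3, 8)                     # 3 nodes, 8 features
batch = torch.zeros(3, dtype=torch.long)  # one graph in the batch

x_dense, mask = to_dense_batch(x, batch)  # [1, 3, 8], [1, 3]
adj = to_dense_adj(edge_index, batch)     # [1, 3, 3]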
@@ -0,0 +1,20 @@
import torch

from torch_geometric.nn import DMonPooling


def test_dense_dmon_pool():
    batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)
    x = torch.randn((batch_size, num_nodes, channels))
    adj = torch.ones((batch_size, num_nodes, num_nodes))
    mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)

    pool = DMonPooling([channels, channels], num_clusters)

    s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)
    assert s.size() == (2, 20, 10)
    assert x.size() == (2, 10, 16)
    assert adj.size() == (2, 10, 10)
    assert -1 <= spectral_loss <= 0
    assert 0 <= ortho_loss <= 2
    assert -1 <= cluster_loss <= 0
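The asserted range for the orthogonality loss follows from the triangle inequality: both terms inside the Frobenius norm are scaled to unit norm, so their difference has norm in [0, 2]. A quick sanity sketch (hypothetical values, not part of the test suite):

import torch

# Both operands have unit Frobenius norm, so by the triangle inequality
# the norm of their difference lies between 0 and 1 + 1 = 2.
ss = torch.rand(10, 10) + 0.1  # arbitrary positive matrix
i_s = torch.eye(10)
val = torch.norm(ss / torch.norm(ss) - i_s / torch.norm(i_s))
assert 0.0 <= val <= 2.0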
(The remaining two added lines belong to the `__init__.py` update mentioned in the commit message; that diff is collapsed in this view.)

dense_dmon_pool.py
@@ -0,0 +1,142 @@
from typing import List, Union

import torch
import torch.nn.functional as F

EPS = 1e-15


class DMonPooling(torch.nn.Module):
    r"""The spectral modularity pooling operator from the `"Graph Clustering
    with Graph Neural Networks" <https://arxiv.org/abs/2006.16904>`_ paper

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
    \times N \times C}`.
    Returns the learned cluster assignment matrix, the pooled node feature
    matrix, the coarsened symmetrically normalized adjacency matrix, and
    three auxiliary objectives: (1) the spectral loss

    .. math::
        \mathcal{L}_s = - \frac{1}{2m}
        \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}

    where :math:`\mathbf{B}` is the modularity matrix, (2) the orthogonality
    loss

    .. math::
        \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
        {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
        \right\|}_F

    where :math:`C` is the number of clusters, and (3) the cluster loss

    .. math::
        \mathcal{L}_c = \frac{\sqrt{C}}{n}
        {\left\|\sum_i\mathbf{C_i}^{\top}\right\|}_F - 1.

    .. note::
        For an example of using :class:`DMonPooling`, see
        `examples/proteins_dmon_pool.py
        <https://github.com/pyg-team/pytorch_geometric/blob
        /master/examples/proteins_dmon_pool.py>`_.

    Args:
        channels (int or List[int]): List of input and intermediate channels
            used to construct the MLP that produces the cluster assignments.
        k (int): The number of clusters.
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.0`)
    """
    def __init__(self, channels: Union[int, List[int]], k: int,
                 dropout: float = 0.0):
        super().__init__()

        if isinstance(channels, int):
            channels = [channels]

        from torch_geometric.nn.models.mlp import MLP
        self.mlp = MLP(channels + [k], act='selu', batch_norm=False)
        self.dropout = dropout

        self.reset_parameters()

    def reset_parameters(self):
        self.mlp.reset_parameters()

    def forward(self, x, adj, mask=None):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in
                \mathbb{R}^{B \times N \times F}` with batch-size
                :math:`B`, (maximum) number of nodes :math:`N` for each
                graph, and feature dimension :math:`F`. Since the cluster
                assignment matrix :math:`\mathbf{S} \in \mathbb{R}^{B \times
                N \times C}` is being created within this method, the MLP and
                softmax do not have to be applied beforehand.
            adj (Tensor): Symmetrically normalized adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)

        :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
            :class:`Tensor`, :class:`Tensor`, :class:`Tensor`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj

        # Compute soft cluster assignments S via the MLP and a softmax.
        s = self.mlp(x)
        s = F.dropout(s, self.dropout, training=self.training)
        s = torch.softmax(s, dim=-1)

        (batch_size, num_nodes, _), k = x.size(), s.size(-1)

        if mask is not None:
            mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
            x, s = x * mask, s * mask

        # Pooled features X' = S^T X and coarsened adjacency A' = S^T A S.
        out = torch.matmul(s.transpose(1, 2), x)
        out = F.selu(out)
        out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

        # Spectral loss: -Tr(S^T B S) / 2m, with B = A - d d^T / 2m.
        degrees = torch.einsum('ijk->ik', adj).transpose(0, 1)
        m = torch.einsum('ij->', degrees)

        ca = torch.matmul(s.transpose(1, 2), degrees)
        cb = torch.matmul(degrees.transpose(0, 1), s)

        normalizer = torch.matmul(ca, cb) / 2 / m
        decompose = out_adj - normalizer
        spectral_loss = -self._rank3_trace(decompose) / 2 / m
        spectral_loss = torch.mean(spectral_loss)

        # Orthogonality regularization.
        ss = torch.matmul(s.transpose(1, 2), s)
        i_s = torch.eye(k).type_as(ss)
        ortho_loss = torch.norm(
            ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -
            i_s / torch.norm(i_s), dim=(-1, -2))
        ortho_loss = torch.mean(ortho_loss)

        # Cluster size regularization.
        cluster_loss = torch.norm(torch.einsum(
            'ijk->ij', ss)) / adj.size(1) * torch.norm(i_s) - 1

        # Fix and normalize the coarsened adjacency matrix.
        ind = torch.arange(k, device=out_adj.device)
        out_adj[:, ind, ind] = 0
        d = torch.einsum('ijk->ij', out_adj)
        d = torch.sqrt(d)[:, None] + EPS
        out_adj = (out_adj / d) / d.transpose(1, 2)

        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def _rank3_trace(self, x):
        return torch.einsum('ijj->i', x)
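For intuition on the spectral term: the docstring's modularity matrix is B = A - d d^T / 2m, where d is the degree vector and 2m the total edge weight. A single-graph sketch of the loss on a hypothetical toy graph (the adjacency and assignments below are illustrative, not from the commit):

import torch

# Toy 3-node graph: A is the adjacency, d the degree vector.
A = torch.tensor([[0., 1., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])
d = A.sum(dim=1, keepdim=True)  # node degrees, shape [3, 1]
two_m = d.sum()                 # 2m: twice the number of edges

B = A - (d @ d.t()) / two_m     # modularity matrix

# A soft assignment into 2 clusters; the loss decreases when clusters
# capture more intra-cluster edges than expected under a random null model.
S = torch.softmax(torch.randn(3, 2), dim=-1)
spectral_loss = -torch.trace(S.t() @ B @ S) / two_m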