Spectral Modularity Pool layer (#4166)
* DMonPool Example

* DMonPool Tests

* DMonPool Layer

* Updated __init__.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* DMonPool Tests

* DMonPool Example

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* DMonPool Layer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* DMonPool Layer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* DMonPool Tests

* DMonPool Example

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* DMonPool Example

* DMonPool Tests

* DMonPool Layer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dense_dmon_pool.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use MLP

Co-authored-by: fork123aniket <fork123aniket>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored Mar 12, 2022
1 parent 4231111 commit 667e1a8
Showing 4 changed files with 266 additions and 0 deletions.
102 changes: 102 additions & 0 deletions examples/proteins_dmon_pool.py
@@ -0,0 +1,102 @@
from math import ceil

import torch
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGraphConv, DMonPooling, GCNConv
from torch_geometric.utils import to_dense_adj, to_dense_batch

dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS').shuffle()
avg_num_nodes = int(dataset.data.x.size(0) / len(dataset))
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DataLoader(test_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
train_loader = DataLoader(train_dataset, batch_size=20)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels=32):
        super().__init__()

        self.conv1 = GCNConv(in_channels, hidden_channels)
        num_nodes = ceil(0.5 * avg_num_nodes)
        self.pool1 = DMonPooling([hidden_channels, hidden_channels],
                                 num_nodes)

        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
        num_nodes = ceil(0.5 * num_nodes)
        self.pool2 = DMonPooling([hidden_channels, hidden_channels],
                                 num_nodes)

        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index).relu()
        x, mask = to_dense_batch(x, batch)
        adj = to_dense_adj(edge_index, batch)

        _, x, adj, sp1, o1, c1 = self.pool1(x, adj, mask)

        x = self.conv2(x, adj).relu()
        _, x, adj, sp2, o2, c2 = self.pool2(x, adj)

        x = self.conv3(x, adj)

        x = x.mean(dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1), sp1 + sp2 + o1 + o2 + c1 + c2


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)


def train(train_loader):
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out, tot_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y.view(-1)) + tot_loss
        loss.backward()
        loss_all += data.y.size(0) * loss.item()
        optimizer.step()
    return loss_all / len(train_dataset)


@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    loss_all = 0

    for data in loader:
        data = data.to(device)
        pred, tot_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(pred, data.y.view(-1)) + tot_loss
        loss_all += data.y.size(0) * loss.item()
        correct += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item()

    return loss_all / len(loader.dataset), correct / len(loader.dataset)


for epoch in range(100):
    train_loss = train(train_loader)
    _, train_acc = test(train_loader)
    val_loss, val_acc = test(val_loader)
    test_loss, test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '
          f'Train Acc: {train_acc:.3f}, Val Loss: {val_loss:.3f}, '
          f'Val Acc: {val_acc:.3f}, Test Loss: {test_loss:.3f}, '
          f'Test Acc: {test_acc:.3f}')
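
Since DMonPooling consumes dense tensors, the example above converts each sparse mini-batch via to_dense_batch and to_dense_adj before the first pooling step. Below is a minimal sketch of that conversion on a hypothetical two-graph toy batch (illustrative only, not part of the committed files):

import torch
from torch_geometric.utils import to_dense_adj, to_dense_batch

x = torch.randn(5, 3)                      # 5 nodes with 3 features across both graphs
edge_index = torch.tensor([[0, 1, 3, 4],
                           [1, 0, 4, 3]])  # two undirected edges
batch = torch.tensor([0, 0, 0, 1, 1])      # graph 0: nodes 0-2, graph 1: nodes 3-4

x_dense, mask = to_dense_batch(x, batch)   # x_dense: [2, 3, 3], mask: [2, 3]
adj = to_dense_adj(edge_index, batch)      # adj: [2, 3, 3]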
20 changes: 20 additions & 0 deletions test/nn/dense/test_dense_dmon_pool.py
@@ -0,0 +1,20 @@
import torch

from torch_geometric.nn import DMonPooling


def test_dense_dmon_pool():
    batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)
    x = torch.randn((batch_size, num_nodes, channels))
    adj = torch.ones((batch_size, num_nodes, num_nodes))
    mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)

    pool = DMonPooling([channels, channels], num_clusters)

    s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)
    assert s.size() == (2, 20, 10)
    assert x.size() == (2, 10, 16)
    assert adj.size() == (2, 10, 10)
    assert -1 <= spectral_loss <= 0
    assert 0 <= ortho_loss <= 2
    assert -1 <= cluster_loss <= 0
2 changes: 2 additions & 0 deletions torch_geometric/nn/dense/__init__.py
@@ -5,6 +5,7 @@
from .dense_gin_conv import DenseGINConv
from .diff_pool import dense_diff_pool
from .mincut_pool import dense_mincut_pool
from .dense_dmon_pool import DMonPooling

__all__ = [
    'Linear',
@@ -15,6 +16,7 @@
    'DenseSAGEConv',
    'dense_diff_pool',
    'dense_mincut_pool',
    'DMonPooling',
]

lin_classes = __all__[:2]
142 changes: 142 additions & 0 deletions torch_geometric/nn/dense/dense_dmon_pool.py
@@ -0,0 +1,142 @@
from typing import List, Union

import torch
import torch.nn.functional as F

EPS = 1e-15


class DMonPooling(torch.nn.Module):
r"""Spectral modularity pooling operator from the `"Graph Clustering with
Graph Neural Networks" <https://arxiv.org/abs/2006.16904>`_ paper
.. math::
\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{X}
\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})
based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
\times N \times C}`.
Returns learned cluster assignment matrix, pooled node feature matrix,
coarsened symmetrically normalized adjacency matrix, and two auxiliary
objectives: (1) The spectral loss
.. math::
\mathcal{L}_s = - \frac{1}{2m}
\cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}
where :math:`\mathbf{B}` is the modularity matrix, (2) the orthogonality
loss
.. math::
\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
{{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
\right\|}_F
where :math:`C` is the number of clusters, and (3) the cluster loss
.. math::
\mathcal{L}_c = \frac{\sqrt{C}}{n}
{\left\|\sum_i\mathbf{C_i}^{\top}\right\|}_F - 1.
.. note::
For an example of using :class:`DMonPooling`, see
`examples/proteins_dmon_pool.py
<https://github.com/pyg-team/pytorch_geometric/blob
/master/examples/proteins_dmon_pool.py>`_.
Args:
channels (List[int]): List of input and intermediate channels in order
to construct an MLP.
k (int): The number of clusters.
dropout (float, optional): Dropout probability. (default: :obj:`0`)
"""
    def __init__(self, channels: Union[int, List[int]], k: int,
                 dropout: float = 0.0):
        super().__init__()

        if isinstance(channels, int):
            channels = [channels]

        from torch_geometric.nn.models.mlp import MLP
        self.mlp = MLP(channels + [k], act='selu', batch_norm=False)
        self.dropout = dropout

        self.reset_parameters()

    def reset_parameters(self):
        self.mlp.reset_parameters()

    def forward(self, x, adj, mask=None):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in
                \mathbb{R}^{B \times N \times F}`, with batch-size
                :math:`B`, (maximum) number of nodes :math:`N` for each
                graph, and feature dimension :math:`F`. Since the cluster
                assignment matrix :math:`\mathbf{S} \in \mathbb{R}^{B
                \times N \times C}` is being created within this method,
                the MLP and softmax do not have to be applied beforehand.
            adj (Tensor): Symmetrically normalized adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)

        :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
            :class:`Tensor`, :class:`Tensor`, :class:`Tensor`)
        """

        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj

        s = self.mlp(x)
        s = F.dropout(s, self.dropout, training=self.training)
        s = torch.softmax(s, dim=-1)

        (batch_size, num_nodes, _), k = x.size(), s.size(-1)

        if mask is not None:
            mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
            x, s = x * mask, s * mask

        out = torch.matmul(s.transpose(1, 2), x)
        out = F.selu(out)
        out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

        # Spectral loss.
        degrees = torch.einsum('ijk->ik', adj).transpose(0, 1)
        m = torch.einsum('ij->', degrees)

        ca = torch.matmul(s.transpose(1, 2), degrees)
        cb = torch.matmul(degrees.transpose(0, 1), s)

        normalizer = torch.matmul(ca, cb) / 2 / m
        decompose = out_adj - normalizer
        spectral_loss = -self._rank3_trace(decompose) / 2 / m
        spectral_loss = torch.mean(spectral_loss)

        # Orthogonality regularization.
        ss = torch.matmul(s.transpose(1, 2), s)
        i_s = torch.eye(k).type_as(ss)
        ortho_loss = torch.norm(
            ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -
            i_s / torch.norm(i_s), dim=(-1, -2))
        ortho_loss = torch.mean(ortho_loss)

        # Cluster loss.
        cluster_loss = torch.norm(torch.einsum(
            'ijk->ij', ss)) / adj.size(1) * torch.norm(i_s) - 1

        # Fix and normalize coarsened adjacency matrix.
        ind = torch.arange(k, device=out_adj.device)
        out_adj[:, ind, ind] = 0
        d = torch.einsum('ijk->ij', out_adj)
        d = torch.sqrt(d)[:, None] + EPS
        out_adj = (out_adj / d) / d.transpose(1, 2)

        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def _rank3_trace(self, x):
        return torch.einsum('ijj->i', x)
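
For reference, a minimal standalone sketch of calling the layer (hypothetical random inputs, mirroring the shapes asserted in the unit test above; not part of the committed files):

import torch
from torch_geometric.nn import DMonPooling

x = torch.randn(2, 20, 16)                  # [batch_size, num_nodes, channels]
adj = torch.ones(2, 20, 20)                 # dense (here trivially symmetric) adjacency
mask = torch.ones(2, 20, dtype=torch.bool)  # all nodes valid

pool = DMonPooling([16, 16], k=10)
s, out, out_adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)
print(s.shape, out.shape, out_adj.shape)    # [2, 20, 10], [2, 10, 16], [2, 10, 10]
# The three scalar losses can be added to the task loss, as in the PROTEINS example.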
