Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ASAPooling jittable #5395

Merged
merged 2 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- `ASAPooling` is now jittable ([#5395](https://github.com/pyg-team/pytorch_geometric/pull/5395))
- Updated unsupervised `GraphSAGE` example to leverage `LinkNeighborLoader` ([#5317](https://github.com/pyg-team/pytorch_geometric/pull/5317))
- Replace in-place operations with out-of-place ones to align with `torch.scatter_reduce` API ([#5353](https://github.com/pyg-team/pytorch_geometric/pull/5353))
- Breaking bugfix: `PointTransformerConv` now correctly uses `sum` aggregation ([#5332](https://github.com/pyg-team/pytorch_geometric/pull/5332))
Expand Down
4 changes: 4 additions & 0 deletions test/nn/pool/test_asap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

from torch_geometric.nn import ASAPooling, GCNConv, GraphConv
from torch_geometric.testing import is_full_test


def test_asap():
Expand All @@ -18,6 +19,9 @@ def test_asap():
assert out[0].size() == (num_nodes // 2, in_channels)
assert out[1].size() == (2, 2)

if is_full_test():
torch.jit.script(pool.jittable())

pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, add_self_loops=True)
assert pool.__repr__() == ('ASAPooling(16, ratio=0.5)')
out = pool(x, edge_index)
Expand Down
45 changes: 35 additions & 10 deletions torch_geometric/nn/pool/asap.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Callable, Optional, Union
import copy
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_sparse import (
SparseTensor,
fill_diag,
index_select,
matmul,
remove_diag,
)
from torch_sparse import t as transpose

from torch_geometric.nn import LEConv
from torch_geometric.nn.pool.topk_pool import topk
Expand Down Expand Up @@ -69,16 +78,24 @@ def __init__(self, in_channels: int, ratio: Union[float, int] = 0.5,
if self.GNN is not None:
self.gnn_intra_cluster = GNN(self.in_channels, self.in_channels,
**kwargs)
else:
self.gnn_intra_cluster = None
self.reset_parameters()

def reset_parameters(self):
self.lin.reset_parameters()
self.att.reset_parameters()
self.gnn_score.reset_parameters()
if self.GNN is not None:
if self.gnn_intra_cluster is not None:
self.gnn_intra_cluster.reset_parameters()

def forward(self, x, edge_index, edge_weight=None, batch=None):
def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_weight: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor]:
""""""
N = x.size(0)

Expand All @@ -91,7 +108,7 @@ def forward(self, x, edge_index, edge_weight=None, batch=None):
x = x.unsqueeze(-1) if x.dim() == 1 else x

x_pool = x
if self.GNN is not None:
if self.gnn_intra_cluster is not None:
x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
edge_weight=edge_weight)

Expand All @@ -116,24 +133,32 @@ def forward(self, x, edge_index, edge_weight=None, batch=None):
batch = batch[perm]

# Graph coarsening.
row, col = edge_index
row, col = edge_index[0], edge_index[1]
A = SparseTensor(row=row, col=col, value=edge_weight,
sparse_sizes=(N, N))
S = SparseTensor(row=row, col=col, value=score, sparse_sizes=(N, N))
S = S[:, perm]

A = S.t() @ A @ S
S = index_select(S, 1, perm)
A = matmul(matmul(transpose(S), A), S)

if self.add_self_loops:
A = A.fill_diag(1.)
A = fill_diag(A, 1.)
else:
A = A.remove_diag()
A = remove_diag(A)

row, col, edge_weight = A.coo()
edge_index = torch.stack([row, col], dim=0)

return x, edge_index, edge_weight, batch, perm

@torch.jit.unused
def jittable(self) -> 'ASAPooling':
out = copy.deepcopy(self)
out.gnn_score = out.gnn_score.jittable()
if out.gnn_intra_cluster is not None:
out.gnn_intra_cluster = out.gnn_intra_cluster.jittable()
return out

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'ratio={self.ratio})')
20 changes: 13 additions & 7 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add, scatter_max

Expand All @@ -10,16 +11,22 @@
from ..inits import uniform


def topk(x, ratio, batch, min_score=None, tol=1e-7):
def topk(
x: Tensor,
ratio: float,
batch: Tensor,
min_score: Optional[int] = None,
tol: float = 1e-7,
) -> Tensor:
if min_score is not None:
# Make sure that we do not drop all nodes in a graph.
scores_max = scatter_max(x, batch)[0].index_select(0, batch) - tol
scores_min = scores_max.clamp(max=min_score)

perm = (x > scores_min).nonzero(as_tuple=False).view(-1)
perm = (x > scores_min).nonzero().view(-1)
else:
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
batch_size, max_num_nodes = num_nodes.size(0), int(num_nodes.max())

cum_num_nodes = torch.cat(
[num_nodes.new_zeros(1),
Expand All @@ -28,8 +35,7 @@ def topk(x, ratio, batch, min_score=None, tol=1e-7):
index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

dense_x = x.new_full((batch_size * max_num_nodes, ),
torch.finfo(x.dtype).min)
dense_x = x.new_full((batch_size * max_num_nodes, ), -60000.0)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)

Expand All @@ -38,8 +44,8 @@ def topk(x, ratio, batch, min_score=None, tol=1e-7):
perm = perm + cum_num_nodes.view(-1, 1)
perm = perm.view(-1)

if isinstance(ratio, int):
k = num_nodes.new_full((num_nodes.size(0), ), ratio)
if ratio >= 1:
k = num_nodes.new_full((num_nodes.size(0), ), int(ratio))
k = torch.min(k, num_nodes)
else:
k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
Expand Down