Skip to content

Commit

Permalink
[Type Hints] nn.PANPooling (#5852)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Oct 29, 2022
1 parent ec7c49b commit 62e2e52
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5842](https://github.com/pyg-team/pytorch_geometric/pull/5842), [#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768)), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5797](https://github.com/pyg-team/pytorch_geometric/pull/5797), [#5798](https://github.com/pyg-team/pytorch_geometric/pull/5798), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799), [#5800](https://github.com/pyg-team/pytorch_geometric/pull/5800), [#5806](https://github.com/pyg-team/pytorch_geometric/pull/5806), [#5810](https://github.com/pyg-team/pytorch_geometric/pull/5810), [#5811](https://github.com/pyg-team/pytorch_geometric/pull/5811), [#5828](https://github.com/pyg-team/pytorch_geometric/pull/5828), [#5847](https://github.com/pyg-team/pytorch_geometric/pull/5847), [#5851](https://github.com/pyg-team/pytorch_geometric/pull/5851))
- Improved type hint support ([#5842](https://github.com/pyg-team/pytorch_geometric/pull/5842), [#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768)), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5797](https://github.com/pyg-team/pytorch_geometric/pull/5797), [#5798](https://github.com/pyg-team/pytorch_geometric/pull/5798), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799), [#5800](https://github.com/pyg-team/pytorch_geometric/pull/5800), [#5806](https://github.com/pyg-team/pytorch_geometric/pull/5806), [#5810](https://github.com/pyg-team/pytorch_geometric/pull/5810), [#5811](https://github.com/pyg-team/pytorch_geometric/pull/5811), [#5828](https://github.com/pyg-team/pytorch_geometric/pull/5828), [#5847](https://github.com/pyg-team/pytorch_geometric/pull/5847), [#5851](https://github.com/pyg-team/pytorch_geometric/pull/5851), [#5852](https://github.com/pyg-team/pytorch_geometric/pull/5852))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
17 changes: 14 additions & 3 deletions test/nn/pool/test_pan_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

from torch_geometric.nn import PANConv, PANPooling
from torch_geometric.testing import is_full_test


def test_pan_pooling():
Expand All @@ -11,13 +12,23 @@ def test_pan_pooling():

conv = PANConv(16, 32, filter_size=2)
pool = PANPooling(32, ratio=0.5)
assert pool.__repr__() == 'PANPooling(32, ratio=0.5, multiplier=1.0)'
assert str(pool) == 'PANPooling(32, ratio=0.5, multiplier=1.0)'

x, M = conv(x, edge_index)
x, edge_index, edge_weight, batch, perm, score = pool(x, M)
h, edge_index, edge_weight, batch, perm, score = pool(x, M)

assert x.size() == (2, 32)
assert h.size() == (2, 32)
assert edge_index.size() == (2, 4)
assert edge_weight.size() == (4, )
assert perm.size() == (2, )
assert score.size() == (2, )

if is_full_test():
jit = torch.jit.script(pool)
out = jit(x, M)
assert torch.allclose(h, out[0])
assert torch.equal(edge_index, out[1])
assert torch.allclose(edge_weight, out[2])
assert torch.equal(batch, out[3])
assert torch.equal(perm, out[4])
assert torch.allclose(score, out[5])
27 changes: 21 additions & 6 deletions torch_geometric/nn/pool/pan_pool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter
Expand Down Expand Up @@ -38,8 +40,14 @@ class PANPooling(torch.nn.Module):
nonlinearity (torch.nn.functional, optional): The nonlinearity to use.
(default: :obj:`torch.tanh`)
"""
def __init__(self, in_channels, ratio=0.5, min_score=None, multiplier=1.0,
nonlinearity=torch.tanh):
def __init__(
self,
in_channels: int,
ratio: float = 0.5,
min_score: Optional[float] = None,
multiplier: float = 1.0,
nonlinearity: Callable = torch.tanh,
):
super().__init__()

self.in_channels = in_channels
Expand All @@ -57,12 +65,18 @@ def reset_parameters(self):
self.p.data.fill_(1)
self.beta.data.fill_(0.5)

def forward(self, x: Tensor, M: SparseTensor, batch: OptTensor = None):
def forward(
self,
x: Tensor,
M: SparseTensor,
batch: OptTensor = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
""""""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)

row, col, edge_weight = M.coo()
assert edge_weight is not None

score1 = (x * self.p).sum(dim=-1)
score2 = scatter_add(edge_weight, col, dim=0, dim_size=x.size(0))
Expand All @@ -78,10 +92,11 @@ def forward(self, x: Tensor, M: SparseTensor, batch: OptTensor = None):
x = self.multiplier * x if self.multiplier != 1 else x

edge_index = torch.stack([col, row], dim=0)
edge_index, edge_attr = filter_adj(edge_index, edge_weight, perm,
num_nodes=score.size(0))
edge_index, edge_weight = filter_adj(edge_index, edge_weight, perm,
num_nodes=score.size(0))
assert edge_weight is not None

return x, edge_index, edge_attr, batch[perm], perm, score[perm]
return x, edge_index, edge_weight, batch[perm], perm, score[perm]

def __repr__(self) -> str:
if self.min_score is None:
Expand Down

0 comments on commit 62e2e52

Please sign in to comment.