Skip to content

Commit

Permalink
MultiAggregation and aggregation_resolver (#4749)
Browse files Browse the repository at this point in the history
* Add MulAggregation and MultiAggregation

* Fix import issue

* Support torch_geometric.nn.aggr package, note: jit errors to fix

* Add tests for MulAggregation, MultiAggregation, aggregation_resolver and message_passing interface

* Formatting

* Fix __repr for gen aggrs

* Move resolver

* Fix test for MulAggregation

* Add test for new mp interface

* Add test for MultiAggregation

* Minor fix

* Add warming for MulAggregation with 'ptr'

* Resolve aggr to Aggregation module, remove aggrs logic

* changelog

* Fix mul aggregation

* update

* update

* update

* update

* reset

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
lightaime and rusty1s authored Jun 7, 2022
1 parent 893aca5 commit df61109
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
12 changes: 9 additions & 3 deletions test/nn/aggr/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
MaxAggregation,
MeanAggregation,
MinAggregation,
MulAggregation,
PowerMeanAggregation,
SoftmaxAggregation,
StdAggregation,
Expand All @@ -29,7 +30,7 @@ def test_validate():

@pytest.mark.parametrize('Aggregation', [
MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
VarAggregation, StdAggregation
MulAggregation, VarAggregation, StdAggregation
])
def test_basic_aggregation(Aggregation):
x = torch.randn(6, 16)
Expand All @@ -41,7 +42,12 @@ def test_basic_aggregation(Aggregation):

out = aggr(x, index)
assert out.size() == (3, x.size(1))
assert torch.allclose(out, aggr(x, ptr=ptr))

if isinstance(aggr, MulAggregation):
with pytest.raises(NotImplementedError, match="requires 'index'"):
aggr(x, ptr=ptr)
else:
assert torch.allclose(out, aggr(x, ptr=ptr))


@pytest.mark.parametrize('Aggregation',
Expand All @@ -53,7 +59,7 @@ def test_gen_aggregation(Aggregation, learn):
ptr = torch.tensor([0, 2, 5, 6])

aggr = Aggregation(learn=learn)
assert str(aggr) == f'{Aggregation.__name__}()'
assert str(aggr) == f'{Aggregation.__name__}(learn={learn})'

out = aggr(x, index)
assert out.size() == (3, x.size(1))
Expand Down
21 changes: 21 additions & 0 deletions test/nn/aggr/test_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch

from torch_geometric.nn import MultiAggregation


def test_multi_aggr():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])
ptr = torch.tensor([0, 2, 5, 6])

aggrs = ['mean', 'sum', 'max']
aggr = MultiAggregation(aggrs)
assert str(aggr) == ('MultiAggregation([\n'
' MeanAggregation(),\n'
' SumAggregation(),\n'
' MaxAggregation()\n'
'])')

out = aggr(x, index)
assert out.size() == (3, len(aggrs) * x.size(1))
assert torch.allclose(out, aggr(x, ptr=ptr))
24 changes: 23 additions & 1 deletion test/nn/test_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest
import torch

from torch_geometric.nn.resolver import activation_resolver
import torch_geometric
from torch_geometric.nn.resolver import (
activation_resolver,
aggregation_resolver,
)


def test_activation_resolver():
Expand All @@ -11,3 +16,20 @@ def test_activation_resolver():
assert isinstance(activation_resolver('elu'), torch.nn.ELU)
assert isinstance(activation_resolver('relu'), torch.nn.ReLU)
assert isinstance(activation_resolver('prelu'), torch.nn.PReLU)


@pytest.mark.parametrize('aggr_tuple', [
(torch_geometric.nn.aggr.MeanAggregation, 'mean'),
(torch_geometric.nn.aggr.SumAggregation, 'sum'),
(torch_geometric.nn.aggr.MaxAggregation, 'max'),
(torch_geometric.nn.aggr.MinAggregation, 'min'),
(torch_geometric.nn.aggr.MulAggregation, 'mul'),
(torch_geometric.nn.aggr.VarAggregation, 'var'),
(torch_geometric.nn.aggr.StdAggregation, 'std'),
(torch_geometric.nn.aggr.SoftmaxAggregation, 'softmax'),
(torch_geometric.nn.aggr.PowerMeanAggregation, 'powermean'),
])
def test_aggregation_resolver(aggr_tuple):
aggr_module, aggr_repr = aggr_tuple
assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
assert isinstance(aggregation_resolver(aggr_repr), aggr_module)
6 changes: 6 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from .base import Aggregation
from .multi import MultiAggregation
from .basic import (
MeanAggregation,
SumAggregation,
AddAggregation,
MaxAggregation,
MinAggregation,
MulAggregation,
VarAggregation,
StdAggregation,
SoftmaxAggregation,
Expand All @@ -14,10 +17,13 @@

__all__ = classes = [
'Aggregation',
'MultiAggregation',
'MeanAggregation',
'SumAggregation',
'AddAggregation',
'MaxAggregation',
'MinAggregation',
'MulAggregation',
'VarAggregation',
'StdAggregation',
'SoftmaxAggregation',
Expand Down
21 changes: 21 additions & 0 deletions torch_geometric/nn/aggr/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')


AddAggregation = SumAggregation # Alias


class MaxAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
Expand All @@ -36,6 +39,15 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
return self.reduce(x, index, ptr, dim_size, dim, reduce='min')


class MulAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
# TODO Currently, `mul` reduction can only operate on `index`:
self.assert_index_present(index)
return self.reduce(x, index, None, dim_size, dim, reduce='mul')


class VarAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
Expand All @@ -61,6 +73,7 @@ def __init__(self, t: float = 1.0, learn: bool = False):
super().__init__()
self._init_t = t
self.t = Parameter(torch.Tensor(1)) if learn else t
self.learn = learn
self.reset_parameters()

def reset_parameters(self):
Expand All @@ -77,15 +90,20 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
alpha = softmax(alpha, index, ptr, dim_size, dim)
return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(learn={self.learn})')


class PowerMeanAggregation(Aggregation):
def __init__(self, p: float = 1.0, learn: bool = False):
# TODO Learn distinct `p` per channel.
super().__init__()
self._init_p = p
self.p = Parameter(torch.Tensor(1)) if learn else p
self.learn = learn
self.reset_parameters()

def reset_parameters(self):
if isinstance(self.p, Tensor):
self.p.data.fill_(self._init_p)

Expand All @@ -97,3 +115,6 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
if isinstance(self.p, (int, float)) and self.p == 1:
return out
return out.clamp_(min=0, max=100).pow(1. / self.p)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(learn={self.learn})')
34 changes: 34 additions & 0 deletions torch_geometric/nn/aggr/multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.resolver import aggregation_resolver


class MultiAggregation(Aggregation):
def __init__(self, aggrs: List[Union[Aggregation, str]]):
super().__init__()

if not isinstance(aggrs, (list, tuple)):
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
f"be a list or tuple (got '{type(aggrs)}')")

if len(aggrs) == 0:
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
f"not be empty")

self.aggrs = [aggregation_resolver(aggr) for aggr in aggrs]

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
outs = []
for aggr in self.aggrs:
outs.append(aggr(x, index, ptr=ptr, dim_size=dim_size, dim=dim))
return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]

def __repr__(self) -> str:
args = [f' {aggr}' for aggr in self.aggrs]
return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args))
39 changes: 29 additions & 10 deletions torch_geometric/nn/resolver.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import inspect
from typing import Any, List, Union
from typing import Any, List, Optional, Union

import torch
from torch import Tensor


def normalize_string(s: str) -> str:
return s.lower().replace('-', '').replace('_', '').replace(' ', '')


def resolver(classes: List[Any], query: Union[Any, str], *args, **kwargs):
def resolver(classes: List[Any], query: Union[Any, str],
base_cls: Optional[Any], *args, **kwargs):

if query is None or not isinstance(query, str):
return query

query = normalize_string(query)
query_repr = normalize_string(query)
base_cls_repr = normalize_string(base_cls.__name__) if base_cls else ''
for cls in classes:
if query == normalize_string(cls.__name__):
cls_repr = normalize_string(cls.__name__)
if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]:
if inspect.isclass(cls):
return cls(*args, **kwargs)
else:
return cls

return ValueError(
f"Could not resolve '{query}' among the choices "
f"{set(normalize_string(cls.__name__) for cls in classes)}")
return ValueError(f"Could not resolve '{query}' among the choices "
f"{set(cls.__name__ for cls in classes)}")


# Activation Resolver #########################################################
Expand All @@ -34,11 +36,28 @@ def swish(x: Tensor) -> Tensor:


def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs):
import torch
base_cls = torch.nn.Module

acts = [
act for act in vars(torch.nn.modules.activation).values()
if isinstance(act, type) and issubclass(act, torch.nn.Module)
if isinstance(act, type) and issubclass(act, base_cls)
]
acts += [
swish,
]
return resolver(acts, query, *args, **kwargs)
return resolver(acts, query, base_cls, *args, **kwargs)


# Aggregation Resolver ########################################################


def aggregation_resolver(query: Union[Any, str], *args, **kwargs):
import torch_geometric.nn.aggr as aggrs
base_cls = aggrs.Aggregation

aggrs = [
aggr for aggr in vars(aggrs).values()
if isinstance(aggr, type) and issubclass(aggr, base_cls)
]
return resolver(aggrs, query, base_cls, *args, **kwargs)

0 comments on commit df61109

Please sign in to comment.