Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiAggregation and aggregation_resolver #4749

Merged
merged 21 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
12 changes: 9 additions & 3 deletions test/nn/aggr/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
MaxAggregation,
MeanAggregation,
MinAggregation,
MulAggregation,
PowerMeanAggregation,
SoftmaxAggregation,
StdAggregation,
Expand All @@ -29,7 +30,7 @@ def test_validate():

@pytest.mark.parametrize('Aggregation', [
MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
VarAggregation, StdAggregation
MulAggregation, VarAggregation, StdAggregation
])
def test_basic_aggregation(Aggregation):
x = torch.randn(6, 16)
Expand All @@ -41,7 +42,12 @@ def test_basic_aggregation(Aggregation):

out = aggr(x, index)
assert out.size() == (3, x.size(1))
assert torch.allclose(out, aggr(x, ptr=ptr))

if isinstance(aggr, MulAggregation):
with pytest.raises(NotImplementedError, match="requires 'index'"):
aggr(x, ptr=ptr)
else:
assert torch.allclose(out, aggr(x, ptr=ptr))
lightaime marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize('Aggregation',
Expand All @@ -53,7 +59,7 @@ def test_gen_aggregation(Aggregation, learn):
ptr = torch.tensor([0, 2, 5, 6])

aggr = Aggregation(learn=learn)
assert str(aggr) == f'{Aggregation.__name__}()'
assert str(aggr) == f'{Aggregation.__name__}(learn={learn})'

out = aggr(x, index)
assert out.size() == (3, x.size(1))
Expand Down
21 changes: 21 additions & 0 deletions test/nn/aggr/test_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch

from torch_geometric.nn import MultiAggregation


def test_multi_aggr():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])
ptr = torch.tensor([0, 2, 5, 6])

aggrs = ['mean', 'sum', 'max']
aggr = MultiAggregation(aggrs)
assert str(aggr) == ('MultiAggregation([\n'
' MeanAggregation(),\n'
' SumAggregation(),\n'
' MaxAggregation()\n'
'])')

out = aggr(x, index)
assert out.size() == (3, len(aggrs) * x.size(1))
assert torch.allclose(out, aggr(x, ptr=ptr))
24 changes: 23 additions & 1 deletion test/nn/test_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest
import torch

from torch_geometric.nn.resolver import activation_resolver
import torch_geometric
from torch_geometric.nn.resolver import (
activation_resolver,
aggregation_resolver,
)


def test_activation_resolver():
Expand All @@ -11,3 +16,20 @@ def test_activation_resolver():
assert isinstance(activation_resolver('elu'), torch.nn.ELU)
assert isinstance(activation_resolver('relu'), torch.nn.ReLU)
assert isinstance(activation_resolver('prelu'), torch.nn.PReLU)


@pytest.mark.parametrize('aggr_tuple', [
(torch_geometric.nn.aggr.MeanAggregation, 'mean'),
(torch_geometric.nn.aggr.SumAggregation, 'sum'),
(torch_geometric.nn.aggr.MaxAggregation, 'max'),
(torch_geometric.nn.aggr.MinAggregation, 'min'),
(torch_geometric.nn.aggr.MulAggregation, 'mul'),
(torch_geometric.nn.aggr.VarAggregation, 'var'),
(torch_geometric.nn.aggr.StdAggregation, 'std'),
(torch_geometric.nn.aggr.SoftmaxAggregation, 'softmax'),
(torch_geometric.nn.aggr.PowerMeanAggregation, 'powermean'),
])
def test_aggregation_resolver(aggr_tuple):
aggr_module, aggr_repr = aggr_tuple
assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
assert isinstance(aggregation_resolver(aggr_repr), aggr_module)
6 changes: 6 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from .base import Aggregation
from .multi import MultiAggregation
from .basic import (
MeanAggregation,
SumAggregation,
AddAggregation,
MaxAggregation,
MinAggregation,
MulAggregation,
VarAggregation,
StdAggregation,
SoftmaxAggregation,
Expand All @@ -14,10 +17,13 @@

__all__ = classes = [
'Aggregation',
'MultiAggregation',
'MeanAggregation',
'SumAggregation',
'AddAggregation',
'MaxAggregation',
'MinAggregation',
'MulAggregation',
'VarAggregation',
'StdAggregation',
'SoftmaxAggregation',
Expand Down
21 changes: 21 additions & 0 deletions torch_geometric/nn/aggr/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')


AddAggregation = SumAggregation # Alias


class MaxAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
Expand All @@ -36,6 +39,15 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
return self.reduce(x, index, ptr, dim_size, dim, reduce='min')


class MulAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
# TODO Currently, `mul` reduction can only operate on `index`:
self.assert_index_present(index)
return self.reduce(x, index, None, dim_size, dim, reduce='mul')


class VarAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
Expand All @@ -61,6 +73,7 @@ def __init__(self, t: float = 1.0, learn: bool = False):
super().__init__()
self._init_t = t
self.t = Parameter(torch.Tensor(1)) if learn else t
self.learn = learn
self.reset_parameters()

def reset_parameters(self):
Expand All @@ -77,15 +90,20 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
alpha = softmax(alpha, index, ptr, dim_size, dim)
return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(learn={self.learn})')


class PowerMeanAggregation(Aggregation):
def __init__(self, p: float = 1.0, learn: bool = False):
# TODO Learn distinct `p` per channel.
super().__init__()
self._init_p = p
self.p = Parameter(torch.Tensor(1)) if learn else p
self.learn = learn
self.reset_parameters()

def reset_parameters(self):
if isinstance(self.p, Tensor):
self.p.data.fill_(self._init_p)

Expand All @@ -97,3 +115,6 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
if isinstance(self.p, (int, float)) and self.p == 1:
return out
return out.clamp_(min=0, max=100).pow(1. / self.p)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(learn={self.learn})')
34 changes: 34 additions & 0 deletions torch_geometric/nn/aggr/multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.resolver import aggregation_resolver


class MultiAggregation(Aggregation):
def __init__(self, aggrs: List[Union[Aggregation, str]]):
super().__init__()

if not isinstance(aggrs, (list, tuple)):
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
f"be a list or tuple (got '{type(aggrs)}')")

if len(aggrs) == 0:
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
f"not be empty")

self.aggrs = [aggregation_resolver(aggr) for aggr in aggrs]

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
outs = []
for aggr in self.aggrs:
outs.append(aggr(x, index, ptr=ptr, dim_size=dim_size, dim=dim))
return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]

def __repr__(self) -> str:
lightaime marked this conversation as resolved.
Show resolved Hide resolved
args = [f' {aggr}' for aggr in self.aggrs]
return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args))
39 changes: 29 additions & 10 deletions torch_geometric/nn/resolver.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import inspect
from typing import Any, List, Union
from typing import Any, List, Optional, Union

import torch
from torch import Tensor


def normalize_string(s: str) -> str:
return s.lower().replace('-', '').replace('_', '').replace(' ', '')


def resolver(classes: List[Any], query: Union[Any, str], *args, **kwargs):
def resolver(classes: List[Any], query: Union[Any, str],
base_cls: Optional[Any], *args, **kwargs):

if query is None or not isinstance(query, str):
return query

query = normalize_string(query)
query_repr = normalize_string(query)
base_cls_repr = normalize_string(base_cls.__name__) if base_cls else ''
for cls in classes:
if query == normalize_string(cls.__name__):
cls_repr = normalize_string(cls.__name__)
if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]:
if inspect.isclass(cls):
return cls(*args, **kwargs)
else:
return cls

return ValueError(
f"Could not resolve '{query}' among the choices "
f"{set(normalize_string(cls.__name__) for cls in classes)}")
return ValueError(f"Could not resolve '{query}' among the choices "
f"{set(cls.__name__ for cls in classes)}")


# Activation Resolver #########################################################
Expand All @@ -34,11 +36,28 @@ def swish(x: Tensor) -> Tensor:


def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs):
import torch
base_cls = torch.nn.Module

acts = [
act for act in vars(torch.nn.modules.activation).values()
if isinstance(act, type) and issubclass(act, torch.nn.Module)
if isinstance(act, type) and issubclass(act, base_cls)
]
acts += [
swish,
]
return resolver(acts, query, *args, **kwargs)
return resolver(acts, query, base_cls, *args, **kwargs)


# Aggregation Resolver ########################################################


def aggregation_resolver(query: Union[Any, str], *args, **kwargs):
import torch_geometric.nn.aggr as aggrs
base_cls = aggrs.Aggregation

aggrs = [
aggr for aggr in vars(aggrs).values()
if isinstance(aggr, type) and issubclass(aggr, base_cls)
]
return resolver(aggrs, query, base_cls, *args, **kwargs)