Skip to content

Commit

Permalink
torch_geometric.nn.aggr package with base class (#4687)
Browse files Browse the repository at this point in the history
* initial commit

* update

* changelog

* Added basic aggrs, gen aggrs and pna aggrs

* Formatted

* Formatted

* Added test for aggr class

* Formatted

* update

* update

* update

* update

* update

* docstring

* typo

Co-authored-by: lightaime <lightaime@gmail.com>
  • Loading branch information
rusty1s and lightaime authored May 25, 2022
1 parent efffdc3 commit cb92831
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715))
Expand Down
51 changes: 51 additions & 0 deletions test/nn/aggr/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
import torch

from torch_geometric.nn import (
MaxAggregation,
MeanAggregation,
MinAggregation,
PowerMeanAggregation,
SoftmaxAggregation,
StdAggregation,
SumAggregation,
VarAggregation,
)


@pytest.mark.parametrize('Aggregation', [
    MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
    VarAggregation, StdAggregation
])
def test_basic_aggregation(Aggregation):
    """Each parameter-free aggregation reduces per group and agrees between
    the `index`-based and the CSR `ptr`-based call paths."""
    x = torch.randn(6, 16)
    # Three groups of sizes 2, 3 and 1:
    index = torch.tensor([0, 0, 1, 1, 1, 2])
    ptr = torch.tensor([0, 2, 5, 6])  # CSR boundaries of the same grouping.

    aggr = Aggregation()
    assert str(aggr) == f'{Aggregation.__name__}()'

    out_by_index = aggr(x, index)
    out_by_ptr = aggr(x, ptr=ptr)
    assert out_by_index.size() == (3, x.size(1))
    assert torch.allclose(out_by_index, out_by_ptr)


@pytest.mark.parametrize('Aggregation',
                         [SoftmaxAggregation, PowerMeanAggregation])
@pytest.mark.parametrize('learn', [True, False])
def test_gen_aggregation(Aggregation, learn):
    """Generalized aggregations work with both `index` and `ptr`, and the
    learnable parameter receives a non-NaN gradient when `learn=True`."""
    x = torch.randn(6, 16)
    index = torch.tensor([0, 0, 1, 1, 1, 2])
    ptr = torch.tensor([0, 2, 5, 6])

    aggr = Aggregation(learn=learn)
    assert str(aggr) == f'{Aggregation.__name__}()'

    out = aggr(x, index)
    assert out.size() == (3, x.size(1))
    assert torch.allclose(out, aggr(x, ptr=ptr))

    if not learn:
        return

    # Backprop through the aggregation and check no gradient turned NaN.
    out.mean().backward()
    for param in aggr.parameters():
        assert not torch.isnan(param.grad).any()
1 change: 1 addition & 0 deletions torch_geometric/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .data_parallel import DataParallel
from .to_hetero_transformer import to_hetero
from .to_hetero_with_bases_transformer import to_hetero_with_bases
from .aggr import * # noqa
from .conv import * # noqa
from .norm import * # noqa
from .glob import * # noqa
Expand Down
23 changes: 23 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Public interface of the `torch_geometric.nn.aggr` package: the abstract
# `Aggregation` base class plus the concrete aggregation operators.
from .base import Aggregation
from .basic import (
    MeanAggregation,
    SumAggregation,
    MaxAggregation,
    MinAggregation,
    VarAggregation,
    StdAggregation,
    SoftmaxAggregation,
    PowerMeanAggregation,
)

# `classes` is kept as an alias of `__all__` — presumably consumed elsewhere
# in the package (e.g. for docs/registry generation); verify before removing.
__all__ = classes = [
    'Aggregation',
    'MeanAggregation',
    'SumAggregation',
    'MaxAggregation',
    'MinAggregation',
    'VarAggregation',
    'StdAggregation',
    'SoftmaxAggregation',
    'PowerMeanAggregation',
]
62 changes: 62 additions & 0 deletions torch_geometric/nn/aggr/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter, segment_csr


class Aggregation(torch.nn.Module, ABC):
    r"""An abstract base class for implementing custom aggregations."""
    @abstractmethod
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        r"""Computes the aggregation of :obj:`x` over the given groups.

        Args:
            x (torch.Tensor): The source tensor.
            index (torch.LongTensor, optional): The indices of elements for
                applying the aggregation.
                One of :obj:`index` or :obj:`ptr` must be defined.
                (default: :obj:`None`)
            ptr (torch.LongTensor, optional): If given, computes the
                aggregation based on sorted inputs in CSR representation.
                One of :obj:`index` or :obj:`ptr` must be defined.
                (default: :obj:`None`)
            dim_size (int, optional): The size of the output tensor at
                dimension :obj:`dim` after aggregation. (default: :obj:`None`)
            dim (int, optional): The dimension in which to aggregate.
                (default: :obj:`-2`)
        """
        pass

    def reset_parameters(self):
        # No-op by default; subclasses with learnable parameters override it.
        pass

    def reduce(self, x: Tensor, index: Optional[Tensor] = None,
               ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
               dim: int = -2, reduce: str = 'add') -> Tensor:
        # Shared reduction helper for subclasses. The CSR `ptr` path takes
        # precedence over `index` when both are given.
        # NOTE: The former `assert index is not None or ptr is not None` was
        # removed — it made the ValueError below unreachable and would be
        # stripped under `python -O`; the explicit exception is now the
        # actual error path.
        if ptr is not None:
            # `segment_csr` needs `ptr` aligned to dimension `dim` of `x`,
            # hence the prepended size-1 dimensions.
            ptr = expand_left(ptr, dim, dims=x.dim())
            return segment_csr(x, ptr, reduce=reduce)

        if index is not None:
            return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

        raise ValueError(f"Error in '{self.__class__.__name__}': "
                         f"One of 'index' or 'ptr' must be defined")

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'


###############################################################################


def expand_left(ptr: Tensor, dim: int, dims: int) -> Tensor:
    """Prepends size-1 dimensions to :obj:`ptr` so that it sits at dimension
    :obj:`dim` (negative or non-negative) of a :obj:`dims`-dimensional
    tensor."""
    num_leading = dims + dim if dim < 0 else dim
    for _ in range(num_leading):
        ptr = ptr[None]  # Equivalent to `ptr.unsqueeze(0)`.
    return ptr
99 changes: 99 additions & 0 deletions torch_geometric/nn/aggr/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import softmax


class MeanAggregation(Aggregation):
    r"""Aggregates features by taking the per-group average."""
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Delegate to the shared reduction helper of the base class.
        return self.reduce(x, index, ptr=ptr, dim_size=dim_size, dim=dim,
                           reduce='mean')


class SumAggregation(Aggregation):
    r"""Aggregates features by summing within each group."""
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Delegate to the shared reduction helper of the base class.
        return self.reduce(x, index, ptr=ptr, dim_size=dim_size, dim=dim,
                           reduce='sum')


class MaxAggregation(Aggregation):
    r"""Aggregates features by taking the per-group maximum."""
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Delegate to the shared reduction helper of the base class.
        return self.reduce(x, index, ptr=ptr, dim_size=dim_size, dim=dim,
                           reduce='max')


class MinAggregation(Aggregation):
    r"""Aggregates features by taking the per-group minimum."""
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Delegate to the shared reduction helper of the base class.
        return self.reduce(x, index, ptr=ptr, dim_size=dim_size, dim=dim,
                           reduce='min')


class VarAggregation(Aggregation):
    r"""Aggregates features via the per-group (biased) variance, computed as
    :math:`E[x^2] - E[x]^2`."""
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        first_moment = self.reduce(x, index, ptr, dim_size, dim,
                                   reduce='mean')
        second_moment = self.reduce(x * x, index, ptr, dim_size, dim,
                                    reduce='mean')
        return second_moment - first_moment * first_moment


class StdAggregation(VarAggregation):
    r"""Aggregates features via the per-group standard deviation."""
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        var = super().forward(x, index, ptr=ptr, dim_size=dim_size, dim=dim)
        # `relu` clamps tiny negatives from floating-point error; the epsilon
        # keeps the square root differentiable at zero.
        return (var.relu() + 1e-5).sqrt()


class SoftmaxAggregation(Aggregation):
    r"""Softmax-weighted aggregation with an (optionally learnable)
    temperature term :obj:`t`."""
    def __init__(self, t: float = 1.0, learn: bool = False):
        # TODO Learn distinct `t` per channel.
        super().__init__()
        self._init_t = t
        if learn:
            self.t = Parameter(torch.Tensor(1))
        else:
            self.t = t
        self.reset_parameters()

    def reset_parameters(self):
        # Only the learnable variant holds a tensor that needs re-filling.
        if isinstance(self.t, Tensor):
            self.t.data.fill_(self._init_t)

    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Skip the multiplication in the common static `t == 1` case.
        if isinstance(self.t, (int, float)) and self.t == 1:
            alpha = x
        else:
            alpha = x * self.t
        alpha = softmax(alpha, index, ptr, dim_size, dim)
        return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')


class PowerMeanAggregation(Aggregation):
    r"""Power-mean aggregation with an (optionally learnable) power term
    :obj:`p`."""
    def __init__(self, p: float = 1.0, learn: bool = False):
        # TODO Learn distinct `p` per channel.
        super().__init__()
        self._init_p = p
        self.p = Parameter(torch.Tensor(1)) if learn else p
        self.reset_parameters()

    def reset_parameters(self):
        # Fix: the fill of a learnable `p` previously lived inline in
        # `__init__` *after* a no-op `reset_parameters()` call, so invoking
        # `reset_parameters()` later never restored `p`. Mirror
        # `SoftmaxAggregation` and perform the fill here instead.
        if isinstance(self.p, Tensor):
            self.p.data.fill_(self._init_p)

    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # NOTE(review): a true power-mean is `mean(x^p)^(1/p)`; this computes
        # `mean(x)^(1/p)` (no `x^p` inside the mean) — confirm intended.
        out = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
        if isinstance(self.p, (int, float)) and self.p == 1:
            return out  # Plain mean; no clamp or fractional power needed.
        # Clamp keeps the base non-negative (and bounded) before taking the
        # fractional power `1/p`.
        return out.clamp_(min=0, max=100).pow(1. / self.p)
3 changes: 2 additions & 1 deletion torch_geometric/transforms/base_transform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC
from typing import Any


class BaseTransform:
class BaseTransform(ABC):
r"""An abstract base class for writing transforms.
Transforms are a general way to modify and customize
Expand Down

0 comments on commit cb92831

Please sign in to comment.