torch_geometric.nn.aggr package with base class (#4687)
* initial commit * update * changelog * Added basic aggrs, gen aggrs and pna aggrs * Formatted * Formatted * Added test for aggr class * Formatted * update * update * update * update * update * docstring * typo Co-authored-by: lightaime <lightaime@gmail.com>
Showing 7 changed files with 239 additions and 1 deletion.
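For orientation, a minimal usage sketch of the module-based aggregation operators introduced by this commit (the tensor shapes and values below are illustrative, not taken from the diff):

import torch
from torch_geometric.nn import MeanAggregation, SoftmaxAggregation

x = torch.randn(6, 16)                    # 6 elements, 16 features each
index = torch.tensor([0, 0, 1, 1, 1, 2])  # element-to-group assignment

out = MeanAggregation()(x, index)         # shape: [3, 16], one row per group

# Learnable variant: the temperature `t` becomes a trainable parameter.
aggr = SoftmaxAggregation(learn=True)
aggr(x, index).mean().backward()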
@@ -0,0 +1,51 @@
import pytest
import torch

from torch_geometric.nn import (
    MaxAggregation,
    MeanAggregation,
    MinAggregation,
    PowerMeanAggregation,
    SoftmaxAggregation,
    StdAggregation,
    SumAggregation,
    VarAggregation,
)


@pytest.mark.parametrize('Aggregation', [
    MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
    VarAggregation, StdAggregation
])
def test_basic_aggregation(Aggregation):
    x = torch.randn(6, 16)
    index = torch.tensor([0, 0, 1, 1, 1, 2])
    ptr = torch.tensor([0, 2, 5, 6])

    aggr = Aggregation()
    assert str(aggr) == f'{Aggregation.__name__}()'

    out = aggr(x, index)
    assert out.size() == (3, x.size(1))
    assert torch.allclose(out, aggr(x, ptr=ptr))


@pytest.mark.parametrize('Aggregation',
                         [SoftmaxAggregation, PowerMeanAggregation])
@pytest.mark.parametrize('learn', [True, False])
def test_gen_aggregation(Aggregation, learn):
    x = torch.randn(6, 16)
    index = torch.tensor([0, 0, 1, 1, 1, 2])
    ptr = torch.tensor([0, 2, 5, 6])

    aggr = Aggregation(learn=learn)
    assert str(aggr) == f'{Aggregation.__name__}()'

    out = aggr(x, index)
    assert out.size() == (3, x.size(1))
    assert torch.allclose(out, aggr(x, ptr=ptr))

    if learn:
        out.mean().backward()
        for param in aggr.parameters():
            assert not torch.isnan(param.grad).any()
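In these tests, `index` and `ptr` describe the same grouping of the six input rows: `index` assigns each row to a group, while `ptr` stores CSR-style group boundaries over the (already sorted) rows. A small sketch of the correspondence, using the test values (the derivation assumes a sorted `index` with no empty groups):

import torch

index = torch.tensor([0, 0, 1, 1, 1, 2])  # per-row group ids (sorted)
ptr = torch.tensor([0, 2, 5, 6])          # group i spans rows ptr[i]:ptr[i + 1]

# One way to derive `ptr` from a sorted `index`:
counts = torch.bincount(index)
derived_ptr = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])
assert torch.equal(derived_ptr, ptr)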
@@ -0,0 +1,23 @@
from .base import Aggregation
from .basic import (
    MeanAggregation,
    SumAggregation,
    MaxAggregation,
    MinAggregation,
    VarAggregation,
    StdAggregation,
    SoftmaxAggregation,
    PowerMeanAggregation,
)

__all__ = classes = [
    'Aggregation',
    'MeanAggregation',
    'SumAggregation',
    'MaxAggregation',
    'MinAggregation',
    'VarAggregation',
    'StdAggregation',
    'SoftmaxAggregation',
    'PowerMeanAggregation',
]
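Because `__all__` re-exports every class, the aggregations should be importable either from the new subpackage directly or, as the test file above does, from `torch_geometric.nn` (assuming the top-level `torch_geometric.nn` package re-exports this subpackage, which the tests imply); a quick sketch:

from torch_geometric.nn.aggr import MeanAggregation as AggrMean
from torch_geometric.nn import MeanAggregation

# Both names should refer to the same class object.
assert MeanAggregation is AggrMean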
@@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter, segment_csr


class Aggregation(torch.nn.Module, ABC):
    r"""An abstract base class for implementing custom aggregations."""
    @abstractmethod
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        r"""
        Args:
            x (torch.Tensor): The source tensor.
            index (torch.LongTensor, optional): The indices of elements for
                applying the aggregation.
                One of :obj:`index` or :obj:`ptr` must be defined.
                (default: :obj:`None`)
            ptr (torch.LongTensor, optional): If given, computes the
                aggregation based on sorted inputs in CSR representation.
                One of :obj:`index` or :obj:`ptr` must be defined.
                (default: :obj:`None`)
            dim_size (int, optional): The size of the output tensor at
                dimension :obj:`dim` after aggregation. (default: :obj:`None`)
            dim (int, optional): The dimension in which to aggregate.
                (default: :obj:`-2`)
        """
        pass

    def reset_parameters(self):
        pass

    def reduce(self, x: Tensor, index: Optional[Tensor] = None,
               ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
               dim: int = -2, reduce: str = 'add') -> Tensor:

        assert index is not None or ptr is not None

        if ptr is not None:
            ptr = expand_left(ptr, dim, dims=x.dim())
            return segment_csr(x, ptr, reduce=reduce)

        if index is not None:
            return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

        raise ValueError(f"Error in '{self.__class__.__name__}': "
                         f"One of 'index' or 'ptr' must be defined")

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'


###############################################################################


def expand_left(ptr: Tensor, dim: int, dims: int) -> Tensor:
    for _ in range(dims + dim if dim < 0 else dim):
        ptr = ptr.unsqueeze(0)
    return ptr
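The base class leaves `forward` abstract and exposes `reduce` as the shared scatter/segment_csr dispatcher, so a custom aggregation only has to implement `forward`. A hypothetical root-mean-square aggregation, shown as a sketch and not part of this commit:

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation


class RMSAggregation(Aggregation):
    # Hypothetical example, not included in this commit.
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Per-group mean of squares, then an element-wise square root.
        mean_2 = self.reduce(x * x, index, ptr, dim_size, dim, reduce='mean')
        return mean_2.clamp(min=0).sqrt()


x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])
out = RMSAggregation()(x, index)  # shape: [3, 16]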
@@ -0,0 +1,99 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import softmax


class MeanAggregation(Aggregation):
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='mean')


class SumAggregation(Aggregation):
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')


class MaxAggregation(Aggregation):
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='max')


class MinAggregation(Aggregation):
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='min')


class VarAggregation(Aggregation):
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:

        mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
        mean_2 = self.reduce(x * x, index, ptr, dim_size, dim, reduce='mean')
        return mean_2 - mean * mean


class StdAggregation(VarAggregation):
    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:

        var = super().forward(x, index, ptr=ptr, dim_size=dim_size, dim=dim)
        return torch.sqrt(var.relu() + 1e-5)


class SoftmaxAggregation(Aggregation):
    def __init__(self, t: float = 1.0, learn: bool = False):
        # TODO Learn distinct `t` per channel.
        super().__init__()
        self._init_t = t
        self.t = Parameter(torch.Tensor(1)) if learn else t
        self.reset_parameters()

    def reset_parameters(self):
        if isinstance(self.t, Tensor):
            self.t.data.fill_(self._init_t)

    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:

        alpha = x
        if not isinstance(self.t, (int, float)) or self.t != 1:
            alpha = x * self.t
        alpha = softmax(alpha, index, ptr, dim_size, dim)
        return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')


class PowerMeanAggregation(Aggregation):
    def __init__(self, p: float = 1.0, learn: bool = False):
        # TODO Learn distinct `p` per channel.
        super().__init__()
        self._init_p = p
        self.p = Parameter(torch.Tensor(1)) if learn else p
        self.reset_parameters()

    def reset_parameters(self):
        if isinstance(self.p, Tensor):
            self.p.data.fill_(self._init_p)

    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:

        out = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
        if isinstance(self.p, (int, float)) and self.p == 1:
            return out
        return out.clamp_(min=0, max=100).pow(1. / self.p)
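As a sanity check on the classes above: `VarAggregation` computes the biased per-group variance E[x^2] - E[x]^2, and `SoftmaxAggregation` with the default `t = 1` reduces to a softmax-weighted sum over each group. A small sketch comparing both against plain PyTorch on a single group (values are illustrative):

import torch
from torch_geometric.nn import SoftmaxAggregation, VarAggregation

x = torch.randn(4, 8)
index = torch.zeros(4, dtype=torch.long)   # all rows in one group

var = VarAggregation()(x, index)           # shape: [1, 8]
assert torch.allclose(var, x.var(dim=0, unbiased=False), atol=1e-6)

soft = SoftmaxAggregation()(x, index)      # shape: [1, 8]
weights = torch.softmax(x, dim=0)          # per-channel softmax over the group
assert torch.allclose(soft, (weights * x).sum(dim=0), atol=1e-6)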