diff --git a/CHANGELOG.md b/CHANGELOG.md index eb19a4a4265f..0f3b030d0f67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075)) - Added option to pass custom` from_smiles` functionality to `PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073)) - Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029)) - Added support for `EdgeIndex` in `spmm` ([#9026](https://github.com/pyg-team/pytorch_geometric/pull/9026)) diff --git a/test/nn/aggr/test_variance_preserving.py b/test/nn/aggr/test_variance_preserving.py new file mode 100644 index 000000000000..a1ca16c15dda --- /dev/null +++ b/test/nn/aggr/test_variance_preserving.py @@ -0,0 +1,28 @@ +import torch + +from torch_geometric.nn import ( + MeanAggregation, + SumAggregation, + VariancePreservingAggregation, +) + + +def test_variance_preserving(): + x = torch.randn(6, 16) + index = torch.tensor([0, 0, 1, 1, 1, 3]) + ptr = torch.tensor([0, 2, 5, 5, 6]) + + vpa_aggr = VariancePreservingAggregation() + mean_aggr = MeanAggregation() + sum_aggr = SumAggregation() + + out_vpa = vpa_aggr(x, index) + out_mean = mean_aggr(x, index) + out_sum = sum_aggr(x, index) + + # Equivalent formulation: + expected = torch.sqrt(out_mean.abs() * out_sum.abs()) * out_sum.sign() + + assert out_vpa.size() == (4, 16) + assert torch.allclose(out_vpa, expected) + assert torch.allclose(out_vpa, vpa_aggr(x, ptr=ptr)) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index c41d038b8ba2..0b3425849059 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -24,6 +24,7 @@ from .deep_sets import DeepSetsAggregation from .set_transformer import SetTransformerAggregation from .lcm import LCMAggregation +from .variance_preserving import VariancePreservingAggregation __all__ = classes = [ 'Aggregation', @@ -51,4 +52,5 @@ 'DeepSetsAggregation', 'SetTransformerAggregation', 'LCMAggregation', + 'VariancePreservingAggregation', ] diff --git a/torch_geometric/nn/aggr/variance_preserving.py b/torch_geometric/nn/aggr/variance_preserving.py new file mode 100644 index 000000000000..31d1e4c7a7af --- /dev/null +++ b/torch_geometric/nn/aggr/variance_preserving.py @@ -0,0 +1,33 @@ +from typing import Optional + +from torch import Tensor + +from torch_geometric.nn.aggr import Aggregation +from torch_geometric.utils import degree +from torch_geometric.utils._scatter import broadcast + + +class VariancePreservingAggregation(Aggregation): + r"""Performs the Variance Preserving Aggregation (VPA) from the `"GNN-VPA: + A Variance-Preserving Aggregation Strategy for Graph Neural Networks" + `_ paper. + + .. math:: + \mathrm{vpa}(\mathcal{X}) = \frac{1}{\sqrt{|\mathcal{X}|}} + \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i + """ + def forward(self, x: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + + out = self.reduce(x, index, ptr, dim_size, dim, reduce='sum') + + if ptr is not None: + count = ptr.diff().to(out.dtype) + else: + count = degree(index, dim_size, dtype=out.dtype) + + count = count.sqrt().clamp(min=1.0) + count = broadcast(count, ref=out, dim=dim) + + return out / count diff --git a/torch_geometric/utils/_scatter.py b/torch_geometric/utils/_scatter.py index bd32fde5adf5..b4c8f518258a 100644 --- a/torch_geometric/utils/_scatter.py +++ b/torch_geometric/utils/_scatter.py @@ -175,6 +175,7 @@ def scatter( def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor: + dim = ref.dim() + dim if dim < 0 else dim size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1)) return src.view(size).expand_as(ref)