Skip to content

Commit

Permalink
Rename SortAggr to SortAggregation (#5085)
Browse files Browse the repository at this point in the history
* rename sortaggr

* add changelog

* fix text on name
  • Loading branch information
Padarn authored Jul 29, 2022
1 parent 412ae53 commit 3950ff6
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033]))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033]), [#5085](https://github.com/pyg-team/pytorch_geometric/pull/5085))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
14 changes: 7 additions & 7 deletions test/nn/aggr/test_sort.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch

from torch_geometric.nn.aggr import SortAggr
from torch_geometric.nn.aggr import SortAggregation


def test_sort_aggregation_pool():
N_1, N_2 = 4, 6
x = torch.randn(N_1 + N_2, 4)
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

aggr = SortAggr(k=5)
assert str(aggr) == 'SortAggr(k=5)'
aggr = SortAggregation(k=5)
assert str(aggr) == 'SortAggregation(k=5)'

out = aggr(x, index)
assert out.size() == (2, 5 * 4)
Expand All @@ -36,8 +36,8 @@ def test_sort_aggregation_pool_smaller_than_k():
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

# Set k which is bigger than both N_1=4 and N_2=6.
aggr = SortAggr(k=10)
assert str(aggr) == 'SortAggr(k=10)'
aggr = SortAggregation(k=10)
assert str(aggr) == 'SortAggregation(k=10)'

out = aggr(x, index)
assert out.size() == (2, 10 * 4)
Expand All @@ -64,8 +64,8 @@ def test_global_sort_pool_dim_size():
x = torch.randn(N_1 + N_2, 4)
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

aggr = SortAggr(k=5)
assert str(aggr) == 'SortAggr(k=5)'
aggr = SortAggregation(k=5)
assert str(aggr) == 'SortAggregation(k=5)'

# expand batch output by 1
out = aggr(x, index, dim_size=3)
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .set2set import Set2Set
from .scaler import DegreeScalerAggregation
from .equilibrium import EquilibriumAggregation
from .sort import SortAggr
from .sort import SortAggregation
from .gmt import GraphMultisetTransformer
from .attention import AttentionalAggregation

Expand All @@ -34,7 +34,7 @@
'LSTMAggregation',
'Set2Set',
'DegreeScalerAggregation',
'SortAggr',
'SortAggregation',
'GraphMultisetTransformer',
'AttentionalAggregation',
'EquilibriumAggregation',
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/aggr/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch_geometric.nn.aggr import Aggregation


class SortAggr(Aggregation):
class SortAggregation(Aggregation):
r"""The pooling operator from the `"An End-to-End Deep Learning
Architecture for Graph Classification"
<https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf>`_ paper,
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
AttentionalAggregation,
GraphMultisetTransformer,
Set2Set,
SortAggr,
SortAggregation,
)

Set2Set = deprecated(
Expand Down Expand Up @@ -36,7 +36,7 @@ def __call__(self, x, batch=None, size=None):
func_name='nn.glob.global_sort_pool',
)
def global_sort_pool(x, index, k):
module = SortAggr(k=k)
module = SortAggregation(k=k)
return module(x, index=index)


Expand Down

0 comments on commit 3950ff6

Please sign in to comment.