Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate nn.glob package #5039

Merged
merged 16 commits into from
Jul 26, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973),[#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [5000](https://github.com/pyg-team/pytorch_geometric/pull/5000))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039)
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
15 changes: 0 additions & 15 deletions docs/source/modules/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,6 @@ Normalization Layers
:undoc-members:
:exclude-members: training

Global Pooling Layers
---------------------

.. currentmodule:: torch_geometric.nn.glob
.. autosummary::
:nosignatures:
{% for cls in torch_geometric.nn.glob.classes %}
{{ cls }}
{% endfor %}

.. automodule:: torch_geometric.nn.glob
:members:
:undoc-members:
:exclude-members: training

Pooling Layers
--------------

Expand Down
78 changes: 0 additions & 78 deletions test/nn/aggr/test_sort.py
Original file line number Diff line number Diff line change
@@ -1,78 +0,0 @@
import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify why the tests are removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mistake, added back.


from torch_geometric.nn.aggr import SortAggr


def test_global_sort_pool():
Padarn marked this conversation as resolved.
Show resolved Hide resolved
N_1, N_2 = 4, 6
x = torch.randn(N_1 + N_2, 4)
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

aggr = SortAggr(k=5)
assert str(aggr) == 'SortAggr(k=5)'

out = aggr(x, index)
assert out.size() == (2, 5 * 4)

out_dim = out = aggr(x, index, dim=0)
assert torch.allclose(out_dim, out)

out = out.view(2, 5, 4)

# First graph output has been filled up with zeros.
assert out[0, -1].tolist() == [0, 0, 0, 0]

# Nodes are sorted.
expected = 3 - torch.arange(4)
assert out[0, :4, -1].argsort().tolist() == expected.tolist()

expected = 4 - torch.arange(5)
assert out[1, :, -1].argsort().tolist() == expected.tolist()


def test_global_sort_pool_smaller_than_k():
N_1, N_2 = 4, 6
x = torch.randn(N_1 + N_2, 4)
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

# Set k which is bigger than both N_1=4 and N_2=6.
aggr = SortAggr(k=10)
assert str(aggr) == 'SortAggr(k=10)'

out = aggr(x, index)
assert out.size() == (2, 10 * 4)

out_dim = out = aggr(x, index, dim=0)
assert torch.allclose(out_dim, out)

out = out.view(2, 10, 4)

# Both graph outputs have been filled up with zeros.
assert out[0, -1].tolist() == [0, 0, 0, 0]
assert out[1, -1].tolist() == [0, 0, 0, 0]

# Nodes are sorted.
expected = 3 - torch.arange(4)
assert out[0, :4, -1].argsort().tolist() == expected.tolist()

expected = 5 - torch.arange(6)
assert out[1, :6, -1].argsort().tolist() == expected.tolist()


def test_global_sort_pool_dim_size():
N_1, N_2 = 4, 6
x = torch.randn(N_1 + N_2, 4)
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

aggr = SortAggr(k=5)
assert str(aggr) == 'SortAggr(k=5)'

# expand batch output by 1
out = aggr(x, index, dim_size=3)
assert out.size() == (3, 5 * 4)

out = out.view(3, 5, 4)

# Both first and last graph outputs have been filled up with zeros.
assert out[0, -1].tolist() == [0, 0, 0, 0]
assert out[2, -1].tolist() == [0, 0, 0, 0]
72 changes: 0 additions & 72 deletions test/nn/glob/test_glob.py

This file was deleted.

1 change: 1 addition & 0 deletions test/nn/models/test_gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_gnn_explainer_with_existing_self_loops(model, return_type):
[0, 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
Expand Down
1 change: 0 additions & 1 deletion torch_geometric/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .aggr import * # noqa
from .conv import * # noqa
from .norm import * # noqa
from .glob import * # noqa
from .pool import * # noqa
from .unpool import * # noqa
from .dense import * # noqa
Expand Down
49 changes: 0 additions & 49 deletions torch_geometric/nn/glob/__init__.py

This file was deleted.

8 changes: 8 additions & 0 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .asap import ASAPooling
from .pan_pool import PANPooling
from .mem_pool import MemPooling
from .glob import (global_max_pool, global_mean_pool, global_add_pool,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry, this is not what I meant. I meant:

  • Move global_add_pool, global_mean_pool, and global_max_pool to glob/glob.py to pool/glob.py`.
  • Move glob/__init__.py to nn/glob.py and keep the deprecations in there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is that we keep the implementations of global_*_pool as they were (since they are so heavily used).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm sorry I don't fully understand the distinction between these. In both cases you can

from torch_geometric.nn import global_max_pool

is it that you want to keep this?

from torch_geometric.nn.glob import global_max_pool

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or you just want to clearly separate the deprecations?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we only need to
deprecate the import of torch_geometric.nn.glob.

GraphMultisetTransformer, Set2Set, GlobalAttention)

try:
import torch_cluster
Expand Down Expand Up @@ -263,6 +265,12 @@ def nearest(x: Tensor, y: Tensor, batch_x: OptTensor = None,
'radius',
'radius_graph',
'nearest',
'global_max_pool',
'global_add_pool',
'global_mean_pool',
'GraphMultisetTransformer',
'Set2Set',
'GlobalAttention',
]

classes = __all__
Loading