Deprecate nn.glob package #5039
Deleted file (filename not captured in this view):

@@ -1,78 +0,0 @@
import torch

from torch_geometric.nn.aggr import SortAggr


def test_global_sort_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    aggr = SortAggr(k=5)
    assert str(aggr) == 'SortAggr(k=5)'

    out = aggr(x, index)
    assert out.size() == (2, 5 * 4)

    # Passing dim=0 explicitly must match the default behavior.
    out_dim = aggr(x, index, dim=0)
    assert torch.allclose(out_dim, out)

    out = out.view(2, 5, 4)

    # First graph output has been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted in descending order along the last feature channel.
    expected = 3 - torch.arange(4)
    assert out[0, :4, -1].argsort().tolist() == expected.tolist()

    expected = 4 - torch.arange(5)
    assert out[1, :, -1].argsort().tolist() == expected.tolist()


def test_global_sort_pool_smaller_than_k():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    # Choose k larger than both N_1=4 and N_2=6.
    aggr = SortAggr(k=10)
    assert str(aggr) == 'SortAggr(k=10)'

    out = aggr(x, index)
    assert out.size() == (2, 10 * 4)

    out_dim = aggr(x, index, dim=0)
    assert torch.allclose(out_dim, out)

    out = out.view(2, 10, 4)

    # Both graph outputs have been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]
    assert out[1, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted.
    expected = 3 - torch.arange(4)
    assert out[0, :4, -1].argsort().tolist() == expected.tolist()

    expected = 5 - torch.arange(6)
    assert out[1, :6, -1].argsort().tolist() == expected.tolist()


def test_global_sort_pool_dim_size():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    aggr = SortAggr(k=5)
    assert str(aggr) == 'SortAggr(k=5)'

    # Expand the batch output by one additional (empty) graph.
    out = aggr(x, index, dim_size=3)
    assert out.size() == (3, 5 * 4)

    out = out.view(3, 5, 4)

    # Both the first and the last graph outputs have been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]
    assert out[2, -1].tolist() == [0, 0, 0, 0]
This file was deleted.
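For reference, the deleted tests exercise SortAggr, the nn.aggr replacement for the deprecated nn.glob global sort pooling operator. A minimal usage sketch, assuming only the SortAggr interface shown in the tests above (the data here is illustrative, not from the PR):

import torch

from torch_geometric.nn.aggr import SortAggr

# Two graphs with 4 and 6 nodes, each node carrying 4 features.
x = torch.randn(10, 4)
index = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])

# Sort each graph's nodes by their last feature channel, keep the top
# k rows, and zero-pad graphs with fewer than k nodes, yielding an
# output of shape [num_graphs, k * num_features].
aggr = SortAggr(k=5)
out = aggr(x, index)
assert out.size() == (2, 5 * 4)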
Modified file (filename not captured in this view):

@@ -11,6 +11,8 @@
 from .asap import ASAPooling
 from .pan_pool import PANPooling
 from .mem_pool import MemPooling
+from .glob import (global_max_pool, global_mean_pool, global_add_pool,
+                   GraphMultisetTransformer, Set2Set, GlobalAttention)

 try:
     import torch_cluster

Review thread on the added import lines:

Reviewer: Ah sorry, this is not what I meant. I meant: [rest of comment not captured]

Reviewer: The idea here is that we keep the implementations of [rest of comment not captured]

Author: Hmm, sorry, I don't fully understand the distinction between these. In both cases you can [rest of comment not captured] Is it that you want to keep this?

Author: Or do you just want to clearly separate the deprecations?

Reviewer: Yes, we only need to [rest of comment not captured]
@@ -263,6 +265,12 @@ def nearest(x: Tensor, y: Tensor, batch_x: OptTensor = None,
     'radius',
     'radius_graph',
     'nearest',
+    'global_max_pool',
+    'global_add_pool',
+    'global_mean_pool',
+    'GraphMultisetTransformer',
+    'Set2Set',
+    'GlobalAttention',
 ]

 classes = __all__
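The thread above is settling where the deprecated nn.glob entry points should live relative to their new nn.aggr implementations. One common way to "clearly separate the deprecations" is a thin warning shim around the new module; a minimal sketch under that assumption (this wrapper is illustrative, not the code the PR actually adds):

import warnings

from torch_geometric.nn.aggr import SortAggr

def global_sort_pool(x, batch, k):
    # Hypothetical shim: emit a deprecation warning, then delegate to
    # the new SortAggr module (SortAggr holds no learnable state, so
    # constructing it per call is safe).
    warnings.warn(
        "'global_sort_pool' is deprecated, use 'nn.aggr.SortAggr' instead",
        DeprecationWarning, stacklevel=2)
    return SortAggr(k=k)(x, batch, dim=0)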
Review thread on the deleted test file:

Reviewer: Could you please clarify why the tests are removed?

Author: Mistake, added back.