
Add GlobalSortAggr to nn.aggr package. #4957

Merged · 24 commits merged into pyg-team:master on Jul 18, 2022

Conversation

@Padarn (Contributor) commented on Jul 11, 2022:

Add GlobalSortAggr to the new nn.aggr package, replacing the existing global_sort_pool from nn.glob.sort (see the usage sketch below this list).

Addresses an item in #4712.

WIP items to address:

  • the input dim_size is not checked or used
  • the input dim is not yet respected (skipped for now, as to_dense_batch does not support this yet)
  • deprecate the nn.glob version (replace it in existing implementations?)
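
For reference, a rough usage sketch of the API change (the exact GlobalSortAggr constructor and call signatures here are assumptions drawn from this thread, not the final API):

import torch

from torch_geometric.nn.aggr import GlobalSortAggr

x = torch.randn(6, 16)                     # node features
index = torch.tensor([0, 0, 0, 1, 1, 1])   # graph assignment per node

# Old functional form: `k` was passed at call time.
# out = global_sort_pool(x, index, k=3)

# New module form: `k` moves into the constructor.
aggr = GlobalSortAggr(k=3)
out = aggr(x, index=index)                 # assumed shape: [2, 3 * 16]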

@rusty1s (Member) left a comment:

Thank you! Looks amazing. If possible, let's deprecate the corresponding function in torch_geometric.nn.glob in this PR as well.

@Padarn (Contributor, Author) commented on Jul 11, 2022:

Makes sense, I'll add it to the list. Do you mean to apply a deprecated tag, or to remove it from the rest of the codebase? (Maybe these two can be done separately?)

@Padarn changed the title from "[WIP] Add GlobalSortAggr to nn.aggr package." to "Add GlobalSortAggr to nn.aggr package." on Jul 12, 2022
@Padarn (Contributor, Author) commented on Jul 12, 2022:

Not sure why I'm getting failures in the glob tests - do we want to keep them?

global_sort_pool = deprecated(
details="use 'nn.aggr.GlobalSortAggr' instead",
func_name='nn.glob.global_sort_pool',
)(sort.global_sort_pool)
@Padarn (Contributor, Author): We cannot use the new class here, as it's not a drop-in replacement (the new class doesn't take k in the function call).

@Padarn (Contributor, Author): We could make it a kwarg if we want it to work exactly the same way.

@rusty1s (Member): I think you can just define it here:

@deprecated(...)
def global_sort_pool(...):
    module = GlobalSortAggr(...)
    return module(...)

@Padarn (Contributor, Author): Yeah, okay, this is a better solution. Thanks!
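
Spelling that suggestion out, a sketch of the full wrapper (the import path for deprecated and the exact GlobalSortAggr signature are assumptions, not confirmed by this thread):

from torch import Tensor

from torch_geometric.deprecation import deprecated  # assumed import path
from torch_geometric.nn.aggr import GlobalSortAggr

@deprecated(details="use 'nn.aggr.GlobalSortAggr' instead",
            func_name='nn.glob.global_sort_pool')
def global_sort_pool(x: Tensor, batch: Tensor, k: int) -> Tensor:
    # `k` moves from the call site into the module constructor, so the
    # old functional signature keeps working during deprecation:
    return GlobalSortAggr(k=k)(x, index=batch)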

@rusty1s (Member) left a comment:
I think the issue with the failing tests is that there exist multiple test_sort.py files (which pytest does not allow). Try removing glob/aggr/test_sort.py.
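
(For context: when test directories lack __init__.py files, pytest imports each test module by its basename, so two test files with the same name collide at collection time. A hypothetical layout showing the clash:)

test/nn/glob/test_sort.py   # both would import as module 'test_sort'
test/nn/aggr/test_sort.py   # -> pytest raises an 'import file mismatch' error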

@codecov (bot) commented on Jul 12, 2022:

Codecov Report

Merging #4957 (3cd30b6) into master (c6675c4) will increase coverage by 0.00%.
The diff coverage is 94.11%.

@@           Coverage Diff           @@
##           master    #4957   +/-   ##
=======================================
  Coverage   82.79%   82.79%           
=======================================
  Files         330      330           
  Lines       17991    18004   +13     
=======================================
+ Hits        14896    14907   +11     
- Misses       3095     3097    +2     
Impacted Files                         Coverage Δ
torch_geometric/nn/glob/__init__.py    83.33% <60.00%> (-9.53%) ⬇️
torch_geometric/nn/aggr/__init__.py    100.00% <100.00%> (ø)
torch_geometric/nn/aggr/base.py        93.61% <100.00%> (ø)
torch_geometric/nn/aggr/sort.py        100.00% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@Padarn (Contributor, Author) commented on Jul 12, 2022:

Thanks, I'd not have found that without a lot of pain 😓 (sorry about the force push; a local conflict was giving me a lot of trouble).

@Padarn (Contributor, Author) commented on Jul 12, 2022:

Looks okay now. May as well remove the old version at this point. WDYT?

@rusty1s (Member) commented on Jul 12, 2022:

Yes, please remove it :)

@Padarn (Contributor, Author) commented on Jul 12, 2022:

All done :-)

@lightaime (Contributor) left a comment:
Thanks, @Padarn. LGTM!

@Padarn merged commit 4e70e5d into pyg-team:master on Jul 18, 2022.
@lightaime (Contributor) commented on Jul 28, 2022:

Nit-pick: I think it may be better to rename SortAggr to SortAggregation, to be consistent with the other Aggregation classes. WDYT?

@Padarn (Contributor, Author) commented on Jul 29, 2022:

I agree, @lightaime! Follow-up: #5085.
