Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Reduce messages with scatter_add in PyTorch #427

Merged
merged 7 commits into from
Mar 2, 2019

Conversation

lingfanyu
Copy link
Collaborator

@lingfanyu lingfanyu commented Mar 2, 2019

Description

Spmm with uncoalesced sparse matrix in PyTorch has bad performance because the coalesce step is costly. Although coalesce sparse matrix ahead of time could make the forward pass much faster, the backward phase is still slow. Given expression y=spmat * x, in order to calculate gradient of x, we need to do spmat^T * grad_y. But transpose sparse matrix and then coalesce is again very costly.

This PR replaces PyTorch spmm with index_select and scatter_add. This solution is motivated by pytorch_geometric.

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [Model], [Doc], [Feature]])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • To the my best knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change

Changes

  • Replace PyTorch spmm with index_select and scatter_add

Copy link
Member

@jermainewang jermainewang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jermainewang jermainewang merged commit 3cc32a9 into dmlc:master Mar 2, 2019
@lingfanyu lingfanyu deleted the torch_scatter branch March 2, 2019 03:55
zheng-da pushed a commit to zheng-da/dgl that referenced this pull request Mar 6, 2019
* implement pytorch spmm with gather and scatter add

* fix

* replace torch take with index_select

* comments

* comment about pytorch __getitem__ operator pitfall

* typo
zheng-da pushed a commit to zheng-da/dgl that referenced this pull request Mar 6, 2019
* implement pytorch spmm with gather and scatter add

* fix

* replace torch take with index_select

* comments

* comment about pytorch __getitem__ operator pitfall

* typo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants