Support for LSTM aggregator in SAGEConv #4379
Conversation
Thank you :) Left some comments.
Codecov Report
@@ Coverage Diff @@
## master #4379 +/- ##
==========================================
- Coverage 82.68% 82.58% -0.10%
==========================================
Files 312 312
Lines 16118 16135 +17
==========================================
- Hits 13327 13325 -2
- Misses 2791 2810 +19
Continue to review full report at Codecov.
I have converted the LSTM's output from a tuple to a tensor. Kindly review it and let me know if it looks good to you or if something should be changed. Currently, with my changes, LSTM is only available when edge_index is a SparseTensor and message_and_aggregate() is implemented. Please let me know if you have any suggestions for making LSTM work with Tensors as well.
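For context, a minimal sketch of what such a SparseTensor-only path inside message_and_aggregate could look like is shown below. This is purely illustrative and not necessarily the code in this PR: the lstm module (assumed to be created with batch_first=True), the to_dense_batch call with batch_size, and the matmul fallback are my assumptions; only the name rst mirrors the outdated diff further down.

```python
from torch import Tensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import to_dense_batch


def message_and_aggregate(self, adj_t: SparseTensor, x) -> Tensor:
    # Sketch of a SAGEConv method; `self.lstm` is assumed to be an
    # nn.LSTM(in_channels, in_channels, batch_first=True).
    if self.aggr == 'lstm':
        # coo() returns rows in sorted order (SparseTensor keeps a CSR layout);
        # row holds target nodes, col holds source nodes of the transposed adjacency.
        row, col, _ = adj_t.coo()
        x_j = x[0][col]
        # Pad each target node's neighbor features into a [N, max_deg, F] batch.
        dense_x, _ = to_dense_batch(x_j, row, batch_size=adj_t.size(0))
        _, (rst, _) = self.lstm(dense_x)
        return rst.squeeze(0)
    # Other aggregators ('mean', 'max', ...) keep the fused sparse reduction.
    return matmul(adj_t, x[0], reduce=self.aggr)
```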
Thanks for the updates! One thing I think we still need to fix is that this currently only works in case the input is a SparseTensor. IMO, it would be great to have more general support for this directly inside aggregate rather than inside message_and_aggregate. For aggr='lstm', we would need to set self.fuse = False such that message_and_aggregate is not called. Furthermore, we would need to ensure that edge_index_i is sorted, and fail with a ValueError otherwise. We can easily check this via (edge_index_i[:-1] <= edge_index_i[1:]).all(). WDYT?
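Roughly, that proposal could be sketched as follows (illustrative only; the exact names and error message are my assumptions, not the final code):

```python
# Inside SAGEConv.__init__ (sketch): request the plain message()/aggregate() path.
if aggr == 'lstm':
    self.lstm = LSTM(in_channels, in_channels, batch_first=True)
    self.fuse = False  # prevents propagate() from calling message_and_aggregate()

# Inside SAGEConv.aggregate (sketch): LSTM aggregation needs each node's incoming
# messages to arrive as one contiguous block, i.e. edge_index_i must be sorted.
if self.aggr == 'lstm':
    if not bool((edge_index_i[:-1] <= edge_index_i[1:]).all()):
        raise ValueError("LSTM aggregation requires edge_index to be sorted "
                         "by destination node")
```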
torch_geometric/nn/conv/sage_conv.py (Outdated)
```python
    out = rst.squeeze(0)
    return out

elif self.aggr == 'mean':
```
Suggested change:
```
- elif self.aggr == 'mean':
+ else:
```
should work here as well.
Hey @rusty1s, @hunarbatra, how about converting edge_index to a SparseTensor first?

```python
# Inside SAGEConv.forward (assumes SparseTensor is imported from torch_sparse).
# Convert edge_index to a sparse tensor;
# this is required for propagate to call message_and_aggregate.
if isinstance(edge_index, Tensor):
    num_nodes = int(edge_index.max()) + 1
    edge_index = SparseTensor(row=edge_index[0], col=edge_index[1],
                              sparse_sizes=(num_nodes, num_nodes))

# propagate_type: (x: OptPairTensor)
# propagate internally calls message_and_aggregate()
out = self.propagate(edge_index, x=x, size=size)
out = self.lin_l(out)
```
@wilcoln Yes, this works as well. The problem with this approach is that we would need to convert to a SparseTensor first.
@rusty1s Thank you! Yes, having more general support in aggregate() would definitely be better, and I have been working on it. I feel that just checking whether edge_index_i is sorted is not really enough, because it still fails for some common datasets like Cora with dimension issues (when adding root embeddings in update). And thanks @wilcoln, that approach definitely works well, but I noticed that when I convert to a SparseTensor and then apply the LSTM, the F1 score is pretty low (e.g. just 14% for supervised Cora).
@hunarbatra Can you show me an example of how you tried to integrate it into aggregate()?
Sure, @rusty1s. This is how I've implemented it in message() and aggregate():

```python
def message(self, x_j: Tensor, x_i: Tensor) -> Tensor:
    return x_j

def aggregate(self, inputs: Tensor, edge_index_i, edge_index_j) -> Tensor:
    # LSTM
    if (edge_index_i[:-1] <= edge_index_i[1:]).all().tolist() is True:
        x_j = inputs[edge_index_j]
        x, mask = to_dense_batch(x_j, edge_index_i)  # batch=edge_index_i should be "ordered"
        _, (hidden, _) = self.lstm(x)
        out = hidden.squeeze(0)
        return out
    else:
        raise ValueError("edge_index_i is not ordered")
```
Thank you. I pushed the logic to aggregate and added a basic test case.
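For reference, such a test could look roughly like the sketch below (my own sketch, not the exact test that was added; the real one is in test/nn/conv/test_sage_conv.py, linked in the next comment):

```python
import pytest
import torch

from torch_geometric.nn import SAGEConv


def test_lstm_aggr_sage_conv():
    x = torch.randn(4, 8)
    # Sorted by the target row (edge_index[1]), so each node's incoming
    # messages form a contiguous block.
    edge_index = torch.tensor([[1, 2, 3, 0, 2],
                               [0, 0, 0, 1, 1]])

    conv = SAGEConv(8, 32, aggr='lstm')
    assert conv(x, edge_index).size() == (4, 32)

    # An unsorted edge_index (targets [1, 0]) should raise a ValueError.
    bad_edge_index = torch.tensor([[0, 1], [1, 0]])
    with pytest.raises(ValueError):
        conv(x, bad_edge_index)
```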
Hi guys, I am using the latest master branch and cannot seem to get the LSTM aggregation to work. Based on the error, it seems to have something to do with the dimension of the hidden layer when set via lstm?
Do you have a minimal script to reproduce? Please also have a look at https://github.com/pyg-team/pytorch_geometric/blob/master/test/nn/conv/test_sage_conv.py#L59-L73.
Thanks for the test reference. I found the issue: using the aggr parameter seems to have fixed it.
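For anyone hitting the same problem, a minimal usage sketch (assuming a master build that already includes this PR; the shapes are arbitrary):

```python
import torch

from torch_geometric.nn import SAGEConv

x = torch.randn(4, 16)
# For aggr='lstm', edge_index has to be sorted by the target row (edge_index[1]).
edge_index = torch.tensor([[1, 2, 3, 0, 2],
                           [0, 0, 0, 1, 1]])

conv = SAGEConv(16, 32, aggr='lstm')  # pass the aggregator via the aggr argument
out = conv(x, edge_index)
print(out.shape)  # torch.Size([4, 32])
```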
This pull request adds LSTM, max-pool and GCN aggregation to SAGEConv. These aggregators were implemented in the original paper/code, Inductive Representation Learning on Large Graphs, and their absence has been brought up in issue #1147. Adding them to SAGEConv would help others looking to use them with PyTorch Geometric, since DGL supports them too!
Thank you to @rusty1s for guiding me with the LSTM implementation! :)
Notes
- Removed kwargs.setdefault('aggr', 'mean'), because that was making SAGEConv inherit the 'mean' aggr from MessagePassing.
- Added an aggregator_type argument, which is set to mean by default.