Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SAGEConv to use LSTMAggregation #4863

Merged
merged 2 commits into from
Jun 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
2 changes: 1 addition & 1 deletion test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,5 @@ def test_lstm_sage_conv():
assert torch.allclose(conv(x, adj.t()), out)

edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]])
with pytest.raises(ValueError, match="is not sorted by columns"):
with pytest.raises(ValueError, match="'index' tensor is not sorted"):
conv(x, edge_index)
35 changes: 8 additions & 27 deletions torch_geometric/nn/conv/sage_conv.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from typing import Optional, Tuple, Union
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LSTM
from torch_scatter import scatter
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, Size
from torch_geometric.utils import to_dense_batch


class SAGEConv(MessagePassing):
Expand Down Expand Up @@ -76,9 +73,6 @@ def __init__(
bias: bool = True,
**kwargs,
):
kwargs['aggr'] = aggr if aggr != 'lstm' else None
super().__init__(**kwargs)

self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
Expand All @@ -88,6 +82,12 @@ def __init__(
if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)

if aggr == 'lstm':
kwargs['aggr_kwargs'] = dict(in_channels=in_channels[0],
out_channels=in_channels[0])

super().__init__(aggr, **kwargs)

if self.project:
self.lin = Linear(in_channels[0], in_channels[0], bias=True)

Expand Down Expand Up @@ -140,25 +140,6 @@ def message_and_aggregate(self, adj_t: SparseTensor,
adj_t = adj_t.set_value(None, layout=None)
return matmul(adj_t, x[0], reduce=self.aggr)

def aggregate(self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
if self.aggr is not None:
return scatter(x, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)

# LSTM aggregation:
if ptr is None and not torch.all(index[:-1] <= index[1:]):
raise ValueError(f"Can not utilize LSTM-style aggregation inside "
f"'{self.__class__.__name__}' in case the "
f"'edge_index' tensor is not sorted by columns. "
f"Run 'sort_edge_index(..., sort_by_row=False)' "
f"in a pre-processing step.")

x, mask = to_dense_batch(x, batch=index, batch_size=dim_size)
out, _ = self.lstm(x)
return out[:, -1]

def __repr__(self) -> str:
aggr = self.aggr if self.aggr is not None else 'lstm'
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, aggr={aggr})')
f'{self.out_channels}, aggr={self.aggr})')