Skip to content

Commit

Permalink
Refactor SAGEConv to use LSTMAggregation (#4863)
Browse files Browse the repository at this point in the history
* update

* changelog
  • Loading branch information
rusty1s authored Jun 25, 2022
1 parent d700ddb commit 2003408
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 29 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
2 changes: 1 addition & 1 deletion test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,5 @@ def test_lstm_sage_conv():
assert torch.allclose(conv(x, adj.t()), out)

edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]])
with pytest.raises(ValueError, match="is not sorted by columns"):
with pytest.raises(ValueError, match="'index' tensor is not sorted"):
conv(x, edge_index)
35 changes: 8 additions & 27 deletions torch_geometric/nn/conv/sage_conv.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from typing import Optional, Tuple, Union
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LSTM
from torch_scatter import scatter
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, Size
from torch_geometric.utils import to_dense_batch


class SAGEConv(MessagePassing):
Expand Down Expand Up @@ -76,9 +73,6 @@ def __init__(
bias: bool = True,
**kwargs,
):
kwargs['aggr'] = aggr if aggr != 'lstm' else None
super().__init__(**kwargs)

self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
Expand All @@ -88,6 +82,12 @@ def __init__(
if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)

if aggr == 'lstm':
kwargs['aggr_kwargs'] = dict(in_channels=in_channels[0],
out_channels=in_channels[0])

super().__init__(aggr, **kwargs)

if self.project:
self.lin = Linear(in_channels[0], in_channels[0], bias=True)

Expand Down Expand Up @@ -140,25 +140,6 @@ def message_and_aggregate(self, adj_t: SparseTensor,
adj_t = adj_t.set_value(None, layout=None)
return matmul(adj_t, x[0], reduce=self.aggr)

def aggregate(self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
if self.aggr is not None:
return scatter(x, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)

# LSTM aggregation:
if ptr is None and not torch.all(index[:-1] <= index[1:]):
raise ValueError(f"Can not utilize LSTM-style aggregation inside "
f"'{self.__class__.__name__}' in case the "
f"'edge_index' tensor is not sorted by columns. "
f"Run 'sort_edge_index(..., sort_by_row=False)' "
f"in a pre-processing step.")

x, mask = to_dense_batch(x, batch=index, batch_size=dim_size)
out, _ = self.lstm(x)
return out[:, -1]

def __repr__(self) -> str:
aggr = self.aggr if self.aggr is not None else 'lstm'
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, aggr={aggr})')
f'{self.out_channels}, aggr={self.aggr})')

0 comments on commit 2003408

Please sign in to comment.