# Reviewer notes (placed ahead of the first "diff --git" header; patch tools
# are expected to skip leading text -- confirm for the tool in use).
#
# Summary of the hunks below:
#   * CHANGELOG.md: appends PR #4863 to the `torch_geometric.nn.aggr` entry.
#   * test/nn/conv/test_sage_conv.py: the unsorted `edge_index` check in
#     `test_lstm_sage_conv` now matches "'index' tensor is not sorted".
#   * torch_geometric/nn/conv/sage_conv.py: `SAGEConv` removes its own
#     `aggregate()` method (scatter and LSTM paths), calls
#     `super().__init__(aggr, **kwargs)` after normalising `in_channels`, and
#     sets `aggr_kwargs` (in_channels/out_channels) when `aggr == 'lstm'`;
#     `__repr__` prints `self.aggr` directly.
#
# NOTE(review): `kwargs['aggr_kwargs'] = dict(...)` replaces any `aggr_kwargs`
#   value the caller passed when `aggr == 'lstm'` -- confirm this is intended.
# NOTE(review): `from torch.nn import LSTM` is kept as context although the
#   only visible use of `self.lstm` is removed -- verify it is still needed.
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4b6244f6afee..21cbeec8ece2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
 - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
 - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
-- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779))
+- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863))
 - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
 - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
 - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
diff --git a/test/nn/conv/test_sage_conv.py b/test/nn/conv/test_sage_conv.py
index aa96b79cd72a..58afa7ae53bd 100644
--- a/test/nn/conv/test_sage_conv.py
+++ b/test/nn/conv/test_sage_conv.py
@@ -69,5 +69,5 @@ def test_lstm_sage_conv():
     assert torch.allclose(conv(x, adj.t()), out)
 
     edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]])
-    with pytest.raises(ValueError, match="is not sorted by columns"):
+    with pytest.raises(ValueError, match="'index' tensor is not sorted"):
         conv(x, edge_index)
diff --git a/torch_geometric/nn/conv/sage_conv.py b/torch_geometric/nn/conv/sage_conv.py
index ceb97d784d92..d7f861e0c8b1 100644
--- a/torch_geometric/nn/conv/sage_conv.py
+++ b/torch_geometric/nn/conv/sage_conv.py
@@ -1,16 +1,13 @@
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 
-import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import LSTM
-from torch_scatter import scatter
 from torch_sparse import SparseTensor, matmul
 
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import Adj, OptPairTensor, Size
-from torch_geometric.utils import to_dense_batch
 
 
 class SAGEConv(MessagePassing):
@@ -76,9 +73,6 @@ def __init__(
         bias: bool = True,
         **kwargs,
     ):
-        kwargs['aggr'] = aggr if aggr != 'lstm' else None
-        super().__init__(**kwargs)
-
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.normalize = normalize
@@ -88,6 +82,12 @@ def __init__(
         if isinstance(in_channels, int):
             in_channels = (in_channels, in_channels)
 
+        if aggr == 'lstm':
+            kwargs['aggr_kwargs'] = dict(in_channels=in_channels[0],
+                                         out_channels=in_channels[0])
+
+        super().__init__(aggr, **kwargs)
+
         if self.project:
             self.lin = Linear(in_channels[0], in_channels[0], bias=True)
 
@@ -140,25 +140,6 @@ def message_and_aggregate(self, adj_t: SparseTensor,
         adj_t = adj_t.set_value(None, layout=None)
         return matmul(adj_t, x[0], reduce=self.aggr)
 
-    def aggregate(self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
-                  dim_size: Optional[int] = None) -> Tensor:
-        if self.aggr is not None:
-            return scatter(x, index, dim=self.node_dim, dim_size=dim_size,
-                           reduce=self.aggr)
-
-        # LSTM aggregation:
-        if ptr is None and not torch.all(index[:-1] <= index[1:]):
-            raise ValueError(f"Can not utilize LSTM-style aggregation inside "
-                             f"'{self.__class__.__name__}' in case the "
-                             f"'edge_index' tensor is not sorted by columns. "
-                             f"Run 'sort_edge_index(..., sort_by_row=False)' "
-                             f"in a pre-processing step.")
-
-        x, mask = to_dense_batch(x, batch=index, batch_size=dim_size)
-        out, _ = self.lstm(x)
-        return out[:, -1]
-
     def __repr__(self) -> str:
-        aggr = self.aggr if self.aggr is not None else 'lstm'
         return (f'{self.__class__.__name__}({self.in_channels}, '
-                f'{self.out_channels}, aggr={aggr})')
+                f'{self.out_channels}, aggr={self.aggr})')