diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9b725aea3f0b..2c39dcb3e8aa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed `GINEConv` bug with non-sequential input ([#5154](https://github.com/pyg-team/pytorch_geometric/pull/5154))
 - Improved error message ([#5095](https://github.com/pyg-team/pytorch_geometric/pull/5095))
 - Fixed `HGTLoader` bug which produced outputs with missing edge types ([#5067](https://github.com/pyg-team/pytorch_geometric/pull/5067))
 - Fixed dynamic inheritance issue in data batching ([#5051](https://github.com/pyg-team/pytorch_geometric/pull/5051))
diff --git a/test/nn/conv/test_gin_conv.py b/test/nn/conv/test_gin_conv.py
index d820f01b5f6f..2d8a9a3c30cd 100644
--- a/test/nn/conv/test_gin_conv.py
+++ b/test/nn/conv/test_gin_conv.py
@@ -126,6 +126,11 @@ def test_gine_conv_edge_dim():
     out = conv(x, edge_index, edge_attr)
     assert out.size() == (4, 32)
 
+    nn = Lin(16, 32)
+    conv = GINEConv(nn, train_eps=True, edge_dim=8)
+    out = conv(x, edge_index, edge_attr)
+    assert out.size() == (4, 32)
+
 
 def test_static_gin_conv():
     x = torch.randn(3, 4, 16)
diff --git a/torch_geometric/nn/conv/gin_conv.py b/torch_geometric/nn/conv/gin_conv.py
index 2ef6049cb665..7675b2020016 100644
--- a/torch_geometric/nn/conv/gin_conv.py
+++ b/torch_geometric/nn/conv/gin_conv.py
@@ -130,8 +130,9 @@ class GINEConv(MessagePassing):
         - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
           :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
     """
-    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
-                 edge_dim: Optional[int] = None, **kwargs):
+    def __init__(self, nn: torch.nn.Module, eps: float = 0.,
+                 train_eps: bool = False, edge_dim: Optional[int] = None,
+                 **kwargs):
         kwargs.setdefault('aggr', 'add')
         super().__init__(**kwargs)
         self.nn = nn
@@ -141,11 +142,15 @@ def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
         else:
             self.register_buffer('eps', torch.Tensor([eps]))
         if edge_dim is not None:
-            if hasattr(self.nn[0], 'in_features'):
-                in_channels = self.nn[0].in_features
+            if isinstance(self.nn, torch.nn.Sequential):
+                nn = self.nn[0]
+            if hasattr(nn, 'in_features'):
+                in_channels = nn.in_features
+            elif hasattr(nn, 'in_channels'):
+                in_channels = nn.in_channels
             else:
-                in_channels = self.nn[0].in_channels
+                raise ValueError("Could not infer input channels from `nn`.")
             self.lin = Linear(edge_dim, in_channels)
         else:
             self.lin = None
         self.reset_parameters()