Add edge_attr support for ResGatedGraphConv (#8048)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored Sep 19, 2023
1 parent 5f157cd commit 8f52944
Showing 3 changed files with 54 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added
 
+- Added `edge_attr` support to `ResGatedGraphConv` ([#8048](https://github.com/pyg-team/pytorch_geometric/pull/8048))
 - Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054), [#8057](https://github.com/pyg-team/pytorch_geometric/pull/8057))
 - Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
 - Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
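For reference, a minimal usage sketch of the new option, mirroring the updated test below (shapes chosen for illustration): pass `edge_dim` at construction time and hand `edge_attr` to `forward()`.

```python
import torch
from torch_geometric.nn import ResGatedGraphConv

x = torch.randn(4, 8)                                    # 4 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])  # 4 edges
edge_attr = torch.randn(edge_index.size(1), 4)           # one 4-dim feature per edge

conv = ResGatedGraphConv(8, 32, edge_dim=4)
out = conv(x, edge_index, edge_attr)  # shape [4, 32]
```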
34 changes: 19 additions & 15 deletions test/nn/conv/test_res_gated_graph_conv.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 
 import torch_geometric.typing
@@ -7,53 +8,56 @@
 from torch_geometric.utils import to_torch_csc_tensor
 
 
-def test_res_gated_graph_conv():
+@pytest.mark.parametrize('edge_dim', [None, 4])
+def test_res_gated_graph_conv(edge_dim):
     x1 = torch.randn(4, 8)
     x2 = torch.randn(2, 32)
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
+    edge_attr = torch.randn(edge_index.size(1), edge_dim) if edge_dim else None
     adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))
 
-    conv = ResGatedGraphConv(8, 32)
+    conv = ResGatedGraphConv(8, 32, edge_dim=edge_dim)
     assert str(conv) == 'ResGatedGraphConv(8, 32)'
 
-    out = conv(x1, edge_index)
+    out = conv(x1, edge_index, edge_attr)
     assert out.size() == (4, 32)
-    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj1.t(), edge_attr), out, atol=1e-6)
 
     if torch_geometric.typing.WITH_TORCH_SPARSE:
-        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
+        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))
         assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)
 
     if is_full_test():
-        t = '(Tensor, Tensor) -> Tensor'
+        t = '(Tensor, Tensor, OptTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)
+        assert torch.allclose(jit(x1, edge_index, edge_attr), out, atol=1e-6)
 
     if is_full_test() and torch_geometric.typing.WITH_TORCH_SPARSE:
-        t = '(Tensor, SparseTensor) -> Tensor'
+        t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
         assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)
 
     # Test bipartite message passing:
     adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))
 
-    conv = ResGatedGraphConv((8, 32), 32)
+    conv = ResGatedGraphConv((8, 32), 32, edge_dim=edge_dim)
     assert str(conv) == 'ResGatedGraphConv((8, 32), 32)'
 
-    out = conv((x1, x2), edge_index)
+    out = conv((x1, x2), edge_index, edge_attr)
     assert out.size() == (2, 32)
-    assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj1.t(), edge_attr), out, atol=1e-6)
 
     if torch_geometric.typing.WITH_TORCH_SPARSE:
-        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))
+        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 2))
         assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)
 
     if is_full_test():
-        t = '(PairTensor, Tensor) -> Tensor'
+        t = '(PairTensor, Tensor, OptTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)
+        assert torch.allclose(jit((x1, x2), edge_index, edge_attr), out,
+                              atol=1e-6)
 
     if is_full_test() and torch_geometric.typing.WITH_TORCH_SPARSE:
-        t = '(PairTensor, SparseTensor) -> Tensor'
+        t = '(PairTensor, SparseTensor, OptTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
         assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)
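The TorchScript checks above can be reproduced outside the test suite. A short sketch based on the calls in this diff; note that the new `OptTensor` argument must appear in the `jittable()` type signature:

```python
import torch
from torch_geometric.nn import ResGatedGraphConv

conv = ResGatedGraphConv(8, 32, edge_dim=4)
jit = torch.jit.script(conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
```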
45 changes: 34 additions & 11 deletions torch_geometric/nn/conv/res_gated_graph_conv.py
@@ -1,12 +1,13 @@
 from typing import Callable, Optional, Tuple, Union
 
+import torch
 from torch import Tensor
 from torch.nn import Parameter, Sigmoid
 
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import zeros
-from torch_geometric.typing import Adj, PairTensor
+from torch_geometric.typing import Adj, OptTensor, PairTensor
 
 
 class ResGatedGraphConv(MessagePassing):
@@ -33,6 +34,8 @@ class ResGatedGraphConv(MessagePassing):
         out_channels (int): Size of each output sample.
         act (callable, optional): Gating function :math:`\sigma`.
             (default: :meth:`torch.nn.Sigmoid()`)
+        edge_dim (int, optional): Edge feature dimensionality (in case
+            there are any). (default: :obj:`None`)
         bias (bool, optional): If set to :obj:`False`, the layer will not learn
             an additive bias. (default: :obj:`True`)
         root_weight (bool, optional): If set to :obj:`False`, the layer will
@@ -55,6 +58,7 @@ def __init__(
         in_channels: Union[int, Tuple[int, int]],
         out_channels: int,
         act: Optional[Callable] = Sigmoid(),
+        edge_dim: Optional[int] = None,
         root_weight: bool = True,
         bias: bool = True,
         **kwargs,
@@ -66,14 +70,16 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.act = act
+        self.edge_dim = edge_dim
         self.root_weight = root_weight
 
         if isinstance(in_channels, int):
             in_channels = (in_channels, in_channels)
 
-        self.lin_key = Linear(in_channels[1], out_channels)
-        self.lin_query = Linear(in_channels[0], out_channels)
-        self.lin_value = Linear(in_channels[0], out_channels)
+        edge_dim = edge_dim if edge_dim is not None else 0
+        self.lin_key = Linear(in_channels[1] + edge_dim, out_channels)
+        self.lin_query = Linear(in_channels[0] + edge_dim, out_channels)
+        self.lin_value = Linear(in_channels[0] + edge_dim, out_channels)
 
         if root_weight:
             self.lin_skip = Linear(in_channels[1], out_channels, bias=False)
@@ -97,16 +103,24 @@ def reset_parameters(self):
         if self.bias is not None:
             zeros(self.bias)
 
-    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
+    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
+                edge_attr: OptTensor = None) -> Tensor:
 
         if isinstance(x, Tensor):
             x: PairTensor = (x, x)
 
-        k = self.lin_key(x[1])
-        q = self.lin_query(x[0])
-        v = self.lin_value(x[0])
+        # In case edge features are not given, we can compute key, query and
+        # value tensors in node-level space, which is a bit more efficient:
+        if self.edge_dim is None:
+            k = self.lin_key(x[1])
+            q = self.lin_query(x[0])
+            v = self.lin_value(x[0])
+        else:
+            k, q, v = x[1], x[0], x[0]
 
-        # propagate_type: (k: Tensor, q: Tensor, v: Tensor)
-        out = self.propagate(edge_index, k=k, q=q, v=v, size=None)
+        # propagate_type: (k: Tensor, q: Tensor, v: Tensor, edge_attr: OptTensor)  # noqa
+        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr,
+                             size=None)
 
         if self.root_weight:
             out = out + self.lin_skip(x[1])
@@ -116,5 +130,14 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:

         return out
 
-    def message(self, k_i: Tensor, q_j: Tensor, v_j: Tensor) -> Tensor:
+    def message(self, k_i: Tensor, q_j: Tensor, v_j: Tensor,
+                edge_attr: OptTensor) -> Tensor:
+
+        assert (edge_attr is not None) == (self.edge_dim is not None)
+
+        if edge_attr is not None:
+            k_i = self.lin_key(torch.cat([k_i, edge_attr], dim=-1))
+            q_j = self.lin_query(torch.cat([q_j, edge_attr], dim=-1))
+            v_j = self.lin_value(torch.cat([v_j, edge_attr], dim=-1))
+
         return self.act(k_i + q_j) * v_j
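To make the new code path concrete, here is a standalone PyTorch sketch (not part of the PR) of what `message()` now computes for a single edge (j -> i) when edge features are present. The names `x_i`, `x_j`, `e_ij`, and the `lin_*` modules are illustrative stand-ins for the layer's internal `Linear` layers:

```python
import torch
from torch import nn

in_channels, out_channels, edge_dim = 8, 32, 4
lin_key = nn.Linear(in_channels + edge_dim, out_channels)
lin_query = nn.Linear(in_channels + edge_dim, out_channels)
lin_value = nn.Linear(in_channels + edge_dim, out_channels)
act = nn.Sigmoid()

x_i = torch.randn(in_channels)  # target node i
x_j = torch.randn(in_channels)  # source node j
e_ij = torch.randn(edge_dim)    # feature of edge (j, i)

# With edge features, key, query, and value are computed per edge on the
# concatenation of node and edge features, instead of once per node:
k_i = lin_key(torch.cat([x_i, e_ij], dim=-1))
q_j = lin_query(torch.cat([x_j, e_ij], dim=-1))
v_j = lin_value(torch.cat([x_j, e_ij], dim=-1))

msg = act(k_i + q_j) * v_j  # gated message; propagate() aggregates these over j
```

This also shows why the `Linear` layers in `__init__` grow by `edge_dim` input features: the same modules serve both the node-level fast path and the per-edge path.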
