From 1c8b4cd01bc9d83cc9553dd61dcd1445805f5675 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 27 Feb 2022 09:40:23 +0000 Subject: [PATCH 1/3] add virtualnode impl --- torch_geometric/transforms/__init__.py | 2 + torch_geometric/transforms/add_self_loops.py | 3 - torch_geometric/transforms/virtual_node.py | 60 ++++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 torch_geometric/transforms/virtual_node.py diff --git a/torch_geometric/transforms/__init__.py b/torch_geometric/transforms/__init__.py index 59945ce2c261..deb6fe11c28b 100644 --- a/torch_geometric/transforms/__init__.py +++ b/torch_geometric/transforms/__init__.py @@ -47,6 +47,7 @@ from .random_link_split import RandomLinkSplit from .add_metapaths import AddMetaPaths from .largest_connected_components import LargestConnectedComponents +from .virtual_node import VirtualNode __all__ = [ 'BaseTransform', @@ -98,6 +99,7 @@ 'RandomLinkSplit', 'AddMetaPaths', 'LargestConnectedComponents', + 'VirtualNode', ] classes = __all__ diff --git a/torch_geometric/transforms/add_self_loops.py b/torch_geometric/transforms/add_self_loops.py index 3a20b1c2e23f..81579402672a 100644 --- a/torch_geometric/transforms/add_self_loops.py +++ b/torch_geometric/transforms/add_self_loops.py @@ -41,6 +41,3 @@ def __call__(self, data: Union[Data, HeteroData]): setattr(store, self.attr, edge_weight) return data - - def __repr__(self) -> str: - return f'{self.__class__.__name__}()' diff --git a/torch_geometric/transforms/virtual_node.py b/torch_geometric/transforms/virtual_node.py new file mode 100644 index 000000000000..e845f8382ee1 --- /dev/null +++ b/torch_geometric/transforms/virtual_node.py @@ -0,0 +1,60 @@ +from typing import Optional, Union + +import torch +from torch import Tensor + +from torch_geometric.data import Data, HeteroData +from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import add_self_loops + + +class VirtualNode(BaseTransform): + r"""Appends a 
virtual node to the given homogeneous graph which is + connected to all other nodes in the graph as first described in the + `"Neural Message Passing for Quantum Chemistry" + <https://arxiv.org/abs/1704.01212>`_ paper. + The virtual node serves as a global scratch space that each node both reads + from and writes to in every step of message passing. + This allows information to travel long distances during the propagation + phase. + + Node and edge features of the virtual node are added as zero-filled input + features. + Furthermore, special edge types will be added both for in-coming and + out-going information to and from the virtual node. + """ + def __call__(self, data: Data) -> Data: + N, (row, col) = data.num_nodes, data.edge_index + edge_type = data.get('edge_type', torch.zeros_like(row)) + + arange = torch.arange(N, device=row.device) + full = torch.full((N, ), N, dtype=row.dtype, device=row.device) + row = torch.cat([row, arange, full], dim=0) + col = torch.cat([col, full, arange], dim=0) + edge_index = torch.stack([row, col], dim=0) + + C = int(edge_type.max()) + 1 + edge_type_1 = torch.full((N, ), C, dtype=row.dtype, device=row.device) + edge_type_2 = edge_type_1 + 1 + edge_type = torch.cat([edge_type, edge_type_1, edge_type_2], dim=0) + + for key, value in data.items(): + if key == 'edge_index' or key == 'edge_type': + continue + + if data.is_edge_attr(key): + dim = data.__cat_dim__(key, value) + size = list(value.size()) + size[dim] = 2 * data.num_nodes + data[key] = torch.cat([value, value.new_zeros(size)], dim=dim) + + elif data.is_node_attr(key): + dim = data.__cat_dim__(key, value) + size = list(value.size()) + size[dim] = 1 + data[key] = torch.cat([value, value.new_zeros(size)], dim=dim) + + data.edge_index = edge_index + data.edge_type = edge_type + + return data From 77672bdc2c15afb579585f91d2a8c3c650cb865b Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 27 Feb 2022 10:00:09 +0000 Subject: [PATCH 2/3] add test --- test/transforms/test_virtual_node.py | 38 ++++++++++++++++ 
torch_geometric/transforms/virtual_node.py | 50 ++++++++++++---------- 2 files changed, 65 insertions(+), 23 deletions(-) create mode 100644 test/transforms/test_virtual_node.py diff --git a/test/transforms/test_virtual_node.py b/test/transforms/test_virtual_node.py new file mode 100644 index 000000000000..7111c771ad98 --- /dev/null +++ b/test/transforms/test_virtual_node.py @@ -0,0 +1,38 @@ +import torch + +from torch_geometric.data import Data +from torch_geometric.transforms import VirtualNode + + +def test_virtual_node(): + assert str(VirtualNode()) == 'VirtualNode()' + + x = torch.randn(4, 16) + edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]]) + edge_weight = torch.rand(edge_index.size(1)) + edge_attr = torch.randn(edge_index.size(1), 8) + + data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight, + edge_attr=edge_attr, num_nodes=x.size(0)) + + data = VirtualNode()(data) + assert len(data) == 6 + + assert data.x.size() == (5, 16) + assert torch.allclose(data.x[:4], x) + assert data.x[4:].abs().sum() == 0 + + assert data.edge_index.tolist() == [[2, 0, 2, 0, 1, 2, 3, 4, 4, 4, 4], + [3, 1, 0, 4, 4, 4, 4, 0, 1, 2, 3]] + + assert data.edge_weight.size() == (11, ) + assert torch.allclose(data.edge_weight[:3], edge_weight) + assert data.edge_weight[3:].abs().sum() == 8 + + assert data.edge_attr.size() == (11, 8) + assert torch.allclose(data.edge_attr[:3], edge_attr) + assert data.edge_attr[3:].abs().sum() == 0 + + assert data.num_nodes == 5 + + assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] diff --git a/torch_geometric/transforms/virtual_node.py b/torch_geometric/transforms/virtual_node.py index e845f8382ee1..817557bc359a 100644 --- a/torch_geometric/transforms/virtual_node.py +++ b/torch_geometric/transforms/virtual_node.py @@ -1,18 +1,14 @@ -from typing import Optional, Union - import torch from torch import Tensor -from torch_geometric.data import Data, HeteroData +from torch_geometric.data import Data from torch_geometric.transforms 
import BaseTransform -from torch_geometric.utils import add_self_loops class VirtualNode(BaseTransform): - r"""Appends a virtual node to the given homogeneous graph which is - connected to all other nodes in the graph as first described in the - `"Neural Message Passing for Quantum Chemistry" - <https://arxiv.org/abs/1704.01212>`_ paper. + r"""Appends a virtual node to the given homogeneous graph that is connected + to all other nodes as first described in the `"Neural Message Passing for + Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_ paper. The virtual node serves as a global scratch space that each node both reads from and writes to in every step of message passing. This allows information to travel long distances during the propagation @@ -24,37 +20,45 @@ class VirtualNode(BaseTransform): out-going information to and from the virtual node. """ def __call__(self, data: Data) -> Data: - N, (row, col) = data.num_nodes, data.edge_index + num_nodes, (row, col) = data.num_nodes, data.edge_index edge_type = data.get('edge_type', torch.zeros_like(row)) - arange = torch.arange(N, device=row.device) - full = torch.full((N, ), N, dtype=row.dtype, device=row.device) + arange = torch.arange(num_nodes, device=row.device) + full = row.new_full((num_nodes, ), num_nodes) row = torch.cat([row, arange, full], dim=0) col = torch.cat([col, full, arange], dim=0) edge_index = torch.stack([row, col], dim=0) - C = int(edge_type.max()) + 1 - edge_type_1 = torch.full((N, ), C, dtype=row.dtype, device=row.device) - edge_type_2 = edge_type_1 + 1 - edge_type = torch.cat([edge_type, edge_type_1, edge_type_2], dim=0) + new_type = edge_type.new_full((num_nodes, ), int(edge_type.max()) + 1) + edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0) for key, value in data.items(): if key == 'edge_index' or key == 'edge_type': continue - if data.is_edge_attr(key): + if isinstance(value, Tensor): dim = data.__cat_dim__(key, value) size = list(value.size()) - size[dim] = 2 * data.num_nodes - data[key] = torch.cat([value, value.new_zeros(size)], 
dim=dim) - elif data.is_node_attr(key): - dim = data.__cat_dim__(key, value) - size = list(value.size()) - size[dim] = 1 - data[key] = torch.cat([value, value.new_zeros(size)], dim=dim) + fill_value = None + if key == 'edge_weight': + size[dim] = 2 * num_nodes + fill_value = 1. + elif data.is_edge_attr(key): + size[dim] = 2 * num_nodes + fill_value = 0. + elif data.is_node_attr(key): + size[dim] = 1 + fill_value = 0. + + if fill_value is not None: + new_value = value.new_full(size, fill_value) + data[key] = torch.cat([value, new_value], dim=dim) data.edge_index = edge_index data.edge_type = edge_type + if 'num_nodes' in data: + data.num_nodes = data.num_nodes + 1 + return data From dc3d007d0a5f3371dbfe93891a9f54bc9aef1ea2 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 27 Feb 2022 11:04:43 +0100 Subject: [PATCH 3/3] typo --- torch_geometric/transforms/virtual_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/virtual_node.py b/torch_geometric/transforms/virtual_node.py index 817557bc359a..9e4398261b03 100644 --- a/torch_geometric/transforms/virtual_node.py +++ b/torch_geometric/transforms/virtual_node.py @@ -7,7 +7,7 @@ class VirtualNode(BaseTransform): r"""Appends a virtual node to the given homogeneous graph that is connected - to all other nodes as first described in the `"Neural Message Passing for + to all other nodes, as described in the `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_ paper. The virtual node serves as a global scratch space that each node both reads from and writes to in every step of message passing.