From 8442a1d73b0ad9f33c418febd4f3188c5c1db919 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Fri, 15 Apr 2022 04:52:01 +0000 Subject: [PATCH 1/3] hetero isolated support --- test/transforms/test_remove_isolated_nodes.py | 29 +++++- .../transforms/remove_isolated_nodes.py | 95 +++++++++++++++++-- 2 files changed, 113 insertions(+), 11 deletions(-) diff --git a/test/transforms/test_remove_isolated_nodes.py b/test/transforms/test_remove_isolated_nodes.py index 5040a517d52c..103ca41a427b 100644 --- a/test/transforms/test_remove_isolated_nodes.py +++ b/test/transforms/test_remove_isolated_nodes.py @@ -1,6 +1,6 @@ import torch -from torch_geometric.data import Data +from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import RemoveIsolatedNodes @@ -16,3 +16,30 @@ def test_remove_isolated_nodes(): assert data.edge_index.tolist() == [[0, 1, 0], [1, 0, 0]] assert data.edge_attr.tolist() == [1, 2, 4] assert data.x.tolist() == [[1], [3]] + + data = HeteroData() + data['paper'].x = torch.arange(6).type(torch.float) + data['author'].x = torch.arange(6).type(torch.float) + + # first self loop to be isolated, second not, and third not self loop + data['paper', 'paper'].edge_index = torch.Tensor([[0, 1, 2], + [0, 1, + 3]]).type(torch.long) + + # remove isolation of first paper self loop, and add another + data['paper', 'author'].edge_index = torch.Tensor([[1, 3, 5], + [0, 1, + 2]]).type(torch.long) + + # only edge_attr on one type + data['paper', 'author'].edge_attr = torch.Tensor([1, 2, + 3]).type(torch.long) + + # add a duplicate (in node index) edge + data['paper', 'cites', + 'author'].edge_index = torch.Tensor([[5], [2]]).type(torch.long) + + data = RemoveIsolatedNodes()(data) + assert data['paper'].num_nodes == 4 + assert data['author'].num_nodes == 3 + assert torch.allclose(data['paper'].x, torch.Tensor([1, 2, 3, 5])) diff --git a/torch_geometric/transforms/remove_isolated_nodes.py b/torch_geometric/transforms/remove_isolated_nodes.py index 8c6969bdeeb2..62a451d5a89a 100644 --- a/torch_geometric/transforms/remove_isolated_nodes.py +++ b/torch_geometric/transforms/remove_isolated_nodes.py @@ -1,7 +1,9 @@ import re +from typing import Union import torch +from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import remove_isolated_nodes @@ -11,18 +13,91 @@ class RemoveIsolatedNodes(BaseTransform): r"""Removes isolated nodes from the graph (functional name: :obj:`remove_isolated_nodes`).""" - def __call__(self, data): - num_nodes = data.num_nodes - out = remove_isolated_nodes(data.edge_index, data.edge_attr, num_nodes) - data.edge_index, data.edge_attr, mask = out + def __call__(self, data: Union[Data, + HeteroData]) -> Union[Data, HeteroData]: - if hasattr(data, '__num_nodes__'): - data.num_nodes = int(mask.sum()) + if isinstance(data, Data): + num_nodes = data.num_nodes + out = remove_isolated_nodes(data.edge_index, data.edge_attr, + num_nodes) + data.edge_index, data.edge_attr, mask = out - for key, item in data: + if hasattr(data, '__num_nodes__'): + data.num_nodes = int(mask.sum()) + + for key, item in data: + if bool(re.search('edge', key)): + continue + if torch.is_tensor(item) and item.size(0) == num_nodes: + data[key] = item[mask] + + return data + + elif isinstance(data, HeteroData): + return remove_hetero_isolated_nodes(data) + + raise TypeError(f"`RemoveIsolatedNodes` invalid type: {type(data)}") + + +def remove_hetero_isolated_nodes(data: HeteroData): + r"""Removes the isolated nodes from the heterogenous graph + given by :attr:`data`. + Self-loops are preserved onlt for non-isolated nodes. Nodes with only + a self loop are removed. + + Args: + data (HeteroData): The graph to remove nodes from. + + :rtype: HeteroData + """ + device = data[data.edge_types[0]].edge_index.device + + for node_type in data.node_types: + + num_nodes = data[node_type].num_nodes + mask = torch.zeros(num_nodes, dtype=torch.bool, device=device) + + for edge_type in data.edge_types: + if 'edge_index' in data[edge_type]: + edge_index = data[edge_type].edge_index + if edge_type[0] == edge_type[-1] == node_type: + loop_mask = torch.where(edge_index[0] != edge_index[1]) + mask[edge_index[0][loop_mask[0]]] = 1 + elif edge_type[0] == node_type: + mask[edge_index[0]] = 1 + elif edge_type[-1] == node_type: + mask[edge_index[1]] = 1 + + assoc = torch.full((num_nodes, ), -1, dtype=torch.long, + device=mask.device) + + assoc[mask] = torch.arange(mask.sum(), device=assoc.device) + + for key, values in data[node_type].items(): if bool(re.search('edge', key)): continue - if torch.is_tensor(item) and item.size(0) == num_nodes: - data[key] = item[mask] + if torch.is_tensor(values) and values.size( + 0) == data[node_type].num_nodes: + data[node_type][key] = values[mask] + + for edge_type in data.edge_types: + if 'edge_index' in data[edge_type]: + edge_index = data[edge_type].edge_index + + if edge_type[0] == node_type: + edge_index[0] = assoc[edge_index[0]] + if edge_type[-1] == node_type: + edge_index[1] = assoc[edge_index[1]] + + data[edge_type].edge_index = edge_index + + if edge_type[0] == edge_type[-1] == node_type: + loop_mask = edge_index[0] != -1 + data[edge_type].edge_index = edge_index[:, loop_mask] + if 'edge_attr' in data[edge_type]: + edge_attr = data[edge_type].edge_attr + print(edge_attr.size()) + print(loop_mask.size()) + data[edge_type].edge_attr = edge_attr[loop_mask] - return data + return data From be3e23e7f40a4e1bb05e24846c1476341191714d Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 20 Apr 2022 13:30:39 +0200 Subject: [PATCH 2/3] update --- test/transforms/test_remove_isolated_nodes.py | 62 +++++----- .../transforms/remove_isolated_nodes.py | 116 ++++++------------ 2 files changed, 72 insertions(+), 106 deletions(-) diff --git a/test/transforms/test_remove_isolated_nodes.py b/test/transforms/test_remove_isolated_nodes.py index 103ca41a427b..fbaed4a4ec96 100644 --- a/test/transforms/test_remove_isolated_nodes.py +++ b/test/transforms/test_remove_isolated_nodes.py @@ -5,41 +5,47 @@ def test_remove_isolated_nodes(): - assert RemoveIsolatedNodes().__repr__() == 'RemoveIsolatedNodes()' + assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()' + + data = Data() + data.x = torch.arange(3) + data.edge_index = torch.tensor([[0, 2], [2, 0]]) + data.edge_attr = torch.arange(2) - edge_index = torch.tensor([[0, 2, 1, 0], [2, 0, 1, 0]]) - edge_attr = torch.tensor([1, 2, 3, 4]) - x = torch.tensor([[1], [2], [3]]) - data = Data(edge_index=edge_index, edge_attr=edge_attr, x=x) data = RemoveIsolatedNodes()(data) + assert len(data) == 3 - assert data.edge_index.tolist() == [[0, 1, 0], [1, 0, 0]] - assert data.edge_attr.tolist() == [1, 2, 4] - assert data.x.tolist() == [[1], [3]] + assert data.x.tolist() == [0, 2] + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + assert data.edge_attr.tolist() == [0, 1] + +def test_remove_isolated_nodes_in_hetero_data(): data = HeteroData() - data['paper'].x = torch.arange(6).type(torch.float) - data['author'].x = torch.arange(6).type(torch.float) - # first self loop to be isolated, second not, and third not self loop - data['paper', 'paper'].edge_index = torch.Tensor([[0, 1, 2], - [0, 1, - 3]]).type(torch.long) + data['p'].x = torch.arange(6) + data['a'].x = torch.arange(6) + data['i'].num_nodes = 4 - # remove isolation of first paper self loop, and add another - data['paper', 'author'].edge_index = torch.Tensor([[1, 3, 5], - [0, 1, - 2]]).type(torch.long) + # isolated paper nodes: {4} + # isolated author nodes: {3, 4, 5} + # isolated institution nodes: {0, 1, 2, 3} + data['p', '1', 'p'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 3]]) + data['p', '2', 'a'].edge_index = torch.tensor([[1, 3, 5], [0, 1, 2]]) + data['p', '2', 'a'].edge_attr = torch.arange(3) + data['p', '3', 'a'].edge_index = torch.tensor([[5], [2]]) - # only edge_attr on one type - data['paper', 'author'].edge_attr = torch.Tensor([1, 2, - 3]).type(torch.long) + data = RemoveIsolatedNodes()(data) - # add a duplicate (in node index) edge - data['paper', 'cites', - 'author'].edge_index = torch.Tensor([[5], [2]]).type(torch.long) + assert len(data) == 4 + assert data['p'].num_nodes == 5 + assert data['a'].num_nodes == 3 + assert data['i'].num_nodes == 0 - data = RemoveIsolatedNodes()(data) - assert data['paper'].num_nodes == 4 - assert data['author'].num_nodes == 3 - assert torch.allclose(data['paper'].x, torch.Tensor([1, 2, 3, 5])) + assert data['p'].x.tolist() == [0, 1, 2, 3, 5] + assert data['a'].x.tolist() == [0, 1, 2] + + assert data['1'].edge_index.tolist() == [[0, 1, 2], [0, 1, 3]] + assert data['2'].edge_index.tolist() == [[1, 3, 4], [0, 1, 2]] + assert data['2'].edge_attr.tolist() == [0, 1, 2] + assert data['3'].edge_index.tolist() == [[4], [2]] diff --git a/torch_geometric/transforms/remove_isolated_nodes.py b/torch_geometric/transforms/remove_isolated_nodes.py index 62a451d5a89a..c8496f43c569 100644 --- a/torch_geometric/transforms/remove_isolated_nodes.py +++ b/torch_geometric/transforms/remove_isolated_nodes.py @@ -1,4 +1,4 @@ -import re +from collections import defaultdict from typing import Union import torch @@ -6,98 +6,58 @@ from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import remove_isolated_nodes @functional_transform('remove_isolated_nodes') class RemoveIsolatedNodes(BaseTransform): r"""Removes isolated nodes from the graph (functional name: :obj:`remove_isolated_nodes`).""" - def __call__(self, data: Union[Data, - HeteroData]) -> Union[Data, HeteroData]: - - if isinstance(data, Data): - num_nodes = data.num_nodes - out = remove_isolated_nodes(data.edge_index, data.edge_attr, - num_nodes) - data.edge_index, data.edge_attr, mask = out - - if hasattr(data, '__num_nodes__'): - data.num_nodes = int(mask.sum()) - - for key, item in data: - if bool(re.search('edge', key)): - continue - if torch.is_tensor(item) and item.size(0) == num_nodes: - data[key] = item[mask] - - return data - - elif isinstance(data, HeteroData): - return remove_hetero_isolated_nodes(data) - - raise TypeError(f"`RemoveIsolatedNodes` invalid type: {type(data)}") - - -def remove_hetero_isolated_nodes(data: HeteroData): - r"""Removes the isolated nodes from the heterogenous graph - given by :attr:`data`. - Self-loops are preserved onlt for non-isolated nodes. Nodes with only - a self loop are removed. - - Args: - data (HeteroData): The graph to remove nodes from. - - :rtype: HeteroData - """ - device = data[data.edge_types[0]].edge_index.device + def __call__(self, data: Union[Data, HeteroData]): + # Gather all nodes that occur in at least one edge (across all types): + n_id_dict = defaultdict(list) + for store in data.edge_stores: + if 'edge_index' not in store: + continue - for node_type in data.node_types: + if store._key is None: + src = dst = None + else: + src, _, dst = store._key - num_nodes = data[node_type].num_nodes - mask = torch.zeros(num_nodes, dtype=torch.bool, device=device) + n_id_dict[src].append(store.edge_index[0]) + n_id_dict[dst].append(store.edge_index[1]) - for edge_type in data.edge_types: - if 'edge_index' in data[edge_type]: - edge_index = data[edge_type].edge_index - if edge_type[0] == edge_type[-1] == node_type: - loop_mask = torch.where(edge_index[0] != edge_index[1]) - mask[edge_index[0][loop_mask[0]]] = 1 - elif edge_type[0] == node_type: - mask[edge_index[0]] = 1 - elif edge_type[-1] == node_type: - mask[edge_index[1]] = 1 + n_id_dict = {k: torch.cat(v).unique() for k, v in n_id_dict.items()} - assoc = torch.full((num_nodes, ), -1, dtype=torch.long, - device=mask.device) + n_map_dict = {} + for store in data.node_stores: + if store._key not in n_id_dict: + n_id_dict[store._key] = torch.empty((0, ), dtype=torch.long) - assoc[mask] = torch.arange(mask.sum(), device=assoc.device) + idx = n_id_dict[store._key] + mapping = idx.new_zeros(data.num_nodes) + mapping[idx] = torch.arange(idx.numel()) + n_map_dict[store._key] = mapping - for key, values in data[node_type].items(): - if bool(re.search('edge', key)): + for store in data.edge_stores: + if 'edge_index' not in store: continue - if torch.is_tensor(values) and values.size( - 0) == data[node_type].num_nodes: - data[node_type][key] = values[mask] - for edge_type in data.edge_types: - if 'edge_index' in data[edge_type]: - edge_index = data[edge_type].edge_index + if store._key is None: + src = dst = None + else: + src, _, dst = store._key - if edge_type[0] == node_type: - edge_index[0] = assoc[edge_index[0]] - if edge_type[-1] == node_type: - edge_index[1] = assoc[edge_index[1]] + row = n_map_dict[src][store.edge_index[0]] + col = n_map_dict[dst][store.edge_index[1]] + store.edge_index = torch.stack([row, col], dim=0) - data[edge_type].edge_index = edge_index + for store in data.node_stores: + for key, value in store.items(): + if key == 'num_nodes': + store.num_nodes = n_id_dict[store._key].numel() - if edge_type[0] == edge_type[-1] == node_type: - loop_mask = edge_index[0] != -1 - data[edge_type].edge_index = edge_index[:, loop_mask] - if 'edge_attr' in data[edge_type]: - edge_attr = data[edge_type].edge_attr - print(edge_attr.size()) - print(loop_mask.size()) - data[edge_type].edge_attr = edge_attr[loop_mask] + elif store.is_node_attr(key): + store[key] = value[n_id_dict[store._key]] - return data + return data From a38c5927c24517df54add8130e1fc2deb7a813f6 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 20 Apr 2022 13:32:01 +0200 Subject: [PATCH 3/3] typo --- torch_geometric/transforms/remove_isolated_nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/remove_isolated_nodes.py b/torch_geometric/transforms/remove_isolated_nodes.py index c8496f43c569..bda19f68b8e7 100644 --- a/torch_geometric/transforms/remove_isolated_nodes.py +++ b/torch_geometric/transforms/remove_isolated_nodes.py @@ -36,7 +36,7 @@ def __call__(self, data: Union[Data, HeteroData]): idx = n_id_dict[store._key] mapping = idx.new_zeros(data.num_nodes) - mapping[idx] = torch.arange(idx.numel()) + mapping[idx] = torch.arange(idx.numel(), device=mapping.device) n_map_dict[store._key] = mapping for store in data.edge_stores: