From 5fe70770e2d8b6c8abe11cb57a1e2d2e39a4e353 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Wed, 20 Apr 2022 19:33:01 +0800 Subject: [PATCH] Add `HeteroData` support to `RemoveIsolatedNodes` (#4479) * hetero isolated support * update * typo Co-authored-by: rusty1s --- test/transforms/test_remove_isolated_nodes.py | 51 +++++++++++++--- .../transforms/remove_isolated_nodes.py | 59 +++++++++++++++---- 2 files changed, 89 insertions(+), 21 deletions(-) diff --git a/test/transforms/test_remove_isolated_nodes.py b/test/transforms/test_remove_isolated_nodes.py index 5040a517d52c..fbaed4a4ec96 100644 --- a/test/transforms/test_remove_isolated_nodes.py +++ b/test/transforms/test_remove_isolated_nodes.py @@ -1,18 +1,51 @@ import torch -from torch_geometric.data import Data +from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import RemoveIsolatedNodes def test_remove_isolated_nodes(): - assert RemoveIsolatedNodes().__repr__() == 'RemoveIsolatedNodes()' + assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()' + + data = Data() + data.x = torch.arange(3) + data.edge_index = torch.tensor([[0, 2], [2, 0]]) + data.edge_attr = torch.arange(2) - edge_index = torch.tensor([[0, 2, 1, 0], [2, 0, 1, 0]]) - edge_attr = torch.tensor([1, 2, 3, 4]) - x = torch.tensor([[1], [2], [3]]) - data = Data(edge_index=edge_index, edge_attr=edge_attr, x=x) data = RemoveIsolatedNodes()(data) + assert len(data) == 3 - assert data.edge_index.tolist() == [[0, 1, 0], [1, 0, 0]] - assert data.edge_attr.tolist() == [1, 2, 4] - assert data.x.tolist() == [[1], [3]] + assert data.x.tolist() == [0, 2] + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + assert data.edge_attr.tolist() == [0, 1] + + +def test_remove_isolated_nodes_in_hetero_data(): + data = HeteroData() + + data['p'].x = torch.arange(6) + data['a'].x = torch.arange(6) + data['i'].num_nodes = 4 + + # isolated paper nodes: {4} + # isolated author nodes: {3, 4, 5} + # isolated institution nodes: {0, 1, 2, 3} + data['p', '1', 'p'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 3]]) + data['p', '2', 'a'].edge_index = torch.tensor([[1, 3, 5], [0, 1, 2]]) + data['p', '2', 'a'].edge_attr = torch.arange(3) + data['p', '3', 'a'].edge_index = torch.tensor([[5], [2]]) + + data = RemoveIsolatedNodes()(data) + + assert len(data) == 4 + assert data['p'].num_nodes == 5 + assert data['a'].num_nodes == 3 + assert data['i'].num_nodes == 0 + + assert data['p'].x.tolist() == [0, 1, 2, 3, 5] + assert data['a'].x.tolist() == [0, 1, 2] + + assert data['1'].edge_index.tolist() == [[0, 1, 2], [0, 1, 3]] + assert data['2'].edge_index.tolist() == [[1, 3, 4], [0, 1, 2]] + assert data['2'].edge_attr.tolist() == [0, 1, 2] + assert data['3'].edge_index.tolist() == [[4], [2]] diff --git a/torch_geometric/transforms/remove_isolated_nodes.py b/torch_geometric/transforms/remove_isolated_nodes.py index 8c6969bdeeb2..bda19f68b8e7 100644 --- a/torch_geometric/transforms/remove_isolated_nodes.py +++ b/torch_geometric/transforms/remove_isolated_nodes.py @@ -1,28 +1,63 @@ -import re +from collections import defaultdict +from typing import Union import torch +from torch_geometric.data import Data, HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import remove_isolated_nodes @functional_transform('remove_isolated_nodes') class RemoveIsolatedNodes(BaseTransform): r"""Removes isolated nodes from the graph (functional name: :obj:`remove_isolated_nodes`).""" - def __call__(self, data): - num_nodes = data.num_nodes - out = remove_isolated_nodes(data.edge_index, data.edge_attr, num_nodes) - data.edge_index, data.edge_attr, mask = out + def __call__(self, data: Union[Data, HeteroData]): + # Gather all nodes that occur in at least one edge (across all types): + n_id_dict = defaultdict(list) + for store in data.edge_stores: + if 'edge_index' not in store: + continue + + if store._key is None: + src = dst = None + else: + src, _, dst = store._key + + n_id_dict[src].append(store.edge_index[0]) + n_id_dict[dst].append(store.edge_index[1]) + + n_id_dict = {k: torch.cat(v).unique() for k, v in n_id_dict.items()} + + n_map_dict = {} + for store in data.node_stores: + if store._key not in n_id_dict: + n_id_dict[store._key] = torch.empty((0, ), dtype=torch.long) - if hasattr(data, '__num_nodes__'): - data.num_nodes = int(mask.sum()) + idx = n_id_dict[store._key] + mapping = idx.new_zeros(data.num_nodes) + mapping[idx] = torch.arange(idx.numel(), device=mapping.device) + n_map_dict[store._key] = mapping - for key, item in data: - if bool(re.search('edge', key)): + for store in data.edge_stores: + if 'edge_index' not in store: continue - if torch.is_tensor(item) and item.size(0) == num_nodes: - data[key] = item[mask] + + if store._key is None: + src = dst = None + else: + src, _, dst = store._key + + row = n_map_dict[src][store.edge_index[0]] + col = n_map_dict[dst][store.edge_index[1]] + store.edge_index = torch.stack([row, col], dim=0) + + for store in data.node_stores: + for key, value in store.items(): + if key == 'num_nodes': + store.num_nodes = n_id_dict[store._key].numel() + + elif store.is_node_attr(key): + store[key] = value[n_id_dict[store._key]] return data