Skip to content

Commit

Permalink
Add HeteroData support to RemoveIsolatedNodes (#4479)
Browse files Browse the repository at this point in the history
* hetero isolated support

* update

* typo

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
Padarn and rusty1s authored Apr 20, 2022
1 parent 2797208 commit 5fe7077
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 21 deletions.
51 changes: 42 additions & 9 deletions test/transforms/test_remove_isolated_nodes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,51 @@
import torch

from torch_geometric.data import Data
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import RemoveIsolatedNodes


def test_remove_isolated_nodes():
    """Check that :class:`RemoveIsolatedNodes` drops the unconnected node of
    a homogeneous graph and relabels the remaining edges."""
    assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()'

    # Nodes 0 and 2 are connected in both directions; node 1 has no edge.
    data = Data()
    data.x = torch.arange(3)
    data.edge_index = torch.tensor([[0, 2], [2, 0]])
    data.edge_attr = torch.arange(2)

    data = RemoveIsolatedNodes()(data)

    assert len(data) == 3
    # Only the features of nodes 0 and 2 survive.
    assert data.x.tolist() == [0, 2]
    # Former node 2 is relabelled to 1 once node 1 is removed.
    assert data.edge_index.tolist() == [[0, 1], [1, 0]]
    # No edge is dropped, so edge attributes are unchanged.
    assert data.edge_attr.tolist() == [0, 1]


def test_remove_isolated_nodes_in_hetero_data():
    """Check :class:`RemoveIsolatedNodes` on a heterogeneous graph with
    paper (``'p'``), author (``'a'``) and institution (``'i'``) nodes."""
    graph = HeteroData()

    graph['p'].x = torch.arange(6)
    graph['a'].x = torch.arange(6)
    graph['i'].num_nodes = 4

    # Nodes that appear in none of the edge types below:
    #   paper:       {4}
    #   author:      {3, 4, 5}
    #   institution: {0, 1, 2, 3}
    graph['p', '1', 'p'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 3]])
    graph['p', '2', 'a'].edge_index = torch.tensor([[1, 3, 5], [0, 1, 2]])
    graph['p', '2', 'a'].edge_attr = torch.arange(3)
    graph['p', '3', 'a'].edge_index = torch.tensor([[5], [2]])

    out = RemoveIsolatedNodes()(graph)

    assert len(out) == 4

    # Remaining node counts per node type.
    expected_num_nodes = {'p': 5, 'a': 3, 'i': 0}
    for node_type, count in expected_num_nodes.items():
        assert out[node_type].num_nodes == count

    # Node features are filtered down to the connected nodes.
    assert out['p'].x.tolist() == [0, 1, 2, 3, 5]
    assert out['a'].x.tolist() == [0, 1, 2]

    # Paper 5 becomes index 4 after paper 4 is removed; authors keep their
    # indices because only trailing authors were dropped.
    assert out['1'].edge_index.tolist() == [[0, 1, 2], [0, 1, 3]]
    assert out['2'].edge_index.tolist() == [[1, 3, 4], [0, 1, 2]]
    assert out['2'].edge_attr.tolist() == [0, 1, 2]
    assert out['3'].edge_index.tolist() == [[4], [2]]
59 changes: 47 additions & 12 deletions torch_geometric/transforms/remove_isolated_nodes.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,63 @@
import re
from collections import defaultdict
from typing import Union

import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_isolated_nodes


@functional_transform('remove_isolated_nodes')
class RemoveIsolatedNodes(BaseTransform):
    r"""Removes isolated nodes from the graph
    (functional name: :obj:`remove_isolated_nodes`).

    A node is kept if it is an endpoint of at least one :obj:`edge_index`
    entry in any edge store of the input (so a node that only has a
    self-loop is kept). Works on both :class:`Data` and :class:`HeteroData`;
    the input object is modified in place and returned."""
    @staticmethod
    def _endpoint_types(store):
        # Return the node-store keys of an edge store's source and
        # destination nodes. A `None` key (presumably the single store of a
        # homogeneous `Data` object) maps to the `None`-keyed node store.
        if store._key is None:
            return None, None
        src, _, dst = store._key
        return src, dst

    def __call__(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        # Gather all nodes that occur in at least one edge (across all types):
        n_id_dict = defaultdict(list)
        for store in data.edge_stores:
            if 'edge_index' not in store:
                continue

            src, dst = self._endpoint_types(store)
            n_id_dict[src].append(store.edge_index[0])
            n_id_dict[dst].append(store.edge_index[1])

        # Sorted, de-duplicated ids of the nodes to keep, per node type:
        n_id_dict = {k: torch.cat(v).unique() for k, v in n_id_dict.items()}

        # Build old-index -> new-index lookup tables, per node type. The
        # table length is the total node count, which is an upper bound on
        # the node count of any single type:
        num_nodes = data.num_nodes
        n_map_dict = {}
        for store in data.node_stores:
            if store._key not in n_id_dict:
                # Node type without any incident edge: nothing is kept.
                n_id_dict[store._key] = torch.empty((0, ), dtype=torch.long)

            idx = n_id_dict[store._key]
            mapping = idx.new_zeros(num_nodes)
            mapping[idx] = torch.arange(idx.numel(), device=mapping.device)
            n_map_dict[store._key] = mapping

        # Relabel edge endpoints to the compacted node indices:
        for store in data.edge_stores:
            if 'edge_index' not in store:
                continue

            src, dst = self._endpoint_types(store)
            row = n_map_dict[src][store.edge_index[0]]
            col = n_map_dict[dst][store.edge_index[1]]
            store.edge_index = torch.stack([row, col], dim=0)

        # Filter node-level attributes down to the kept nodes. Only values of
        # existing keys are reassigned, so iterating `items()` stays valid:
        for store in data.node_stores:
            for key, value in store.items():
                if key == 'num_nodes':
                    store.num_nodes = n_id_dict[store._key].numel()

                elif store.is_node_attr(key):
                    store[key] = value[n_id_dict[store._key]]

        return data

0 comments on commit 5fe7077

Please sign in to comment.