Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HeteroData support to RemoveIsolatedNodes #4479

Merged
merged 3 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions test/transforms/test_remove_isolated_nodes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,51 @@
import torch

from torch_geometric.data import Data
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import RemoveIsolatedNodes


def test_remove_isolated_nodes():
    """Check that an unconnected node is dropped from a homogeneous graph.

    The graph has three nodes and the two directed edges 0->2 and 2->0, so
    node 1 occurs in no edge.
    """
    assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()'

    data = Data()
    data.x = torch.arange(3)
    data.edge_index = torch.tensor([[0, 2], [2, 0]])
    data.edge_attr = torch.arange(2)

    data = RemoveIsolatedNodes()(data)

    # The transform must not add or drop attributes (presumably the three
    # counted here are x, edge_index and edge_attr).
    assert len(data) == 3
    # Only the features of nodes 0 and 2 survive ...
    assert data.x.tolist() == [0, 2]
    # ... and the edges are relabelled to the compacted ids (2 -> 1).
    assert data.edge_index.tolist() == [[0, 1], [1, 0]]
    # Edge attributes are untouched because no edge was removed.
    assert data.edge_attr.tolist() == [0, 1]


def test_remove_isolated_nodes_in_hetero_data():
    """Check per-type node removal and edge relabelling on a hetero graph."""
    hetero = HeteroData()

    hetero['p'].x = torch.arange(6)
    hetero['a'].x = torch.arange(6)
    hetero['i'].num_nodes = 4

    # Node ids that take part in no edge, per node type:
    #   'p' (paper):       {4}
    #   'a' (author):      {3, 4, 5}
    #   'i' (institution): {0, 1, 2, 3}
    paper_to_paper = torch.tensor([[0, 1, 2], [0, 1, 3]])
    paper_to_author = torch.tensor([[1, 3, 5], [0, 1, 2]])
    hetero['p', '1', 'p'].edge_index = paper_to_paper
    hetero['p', '2', 'a'].edge_index = paper_to_author
    hetero['p', '2', 'a'].edge_attr = torch.arange(3)
    hetero['p', '3', 'a'].edge_index = torch.tensor([[5], [2]])

    result = RemoveIsolatedNodes()(hetero)

    assert len(result) == 4

    # Remaining node count per type.
    expected_counts = {'p': 5, 'a': 3, 'i': 0}
    for node_type, count in expected_counts.items():
        assert result[node_type].num_nodes == count

    # Features of the surviving nodes, in their original order.
    expected_features = {'p': [0, 1, 2, 3, 5], 'a': [0, 1, 2]}
    for node_type, features in expected_features.items():
        assert result[node_type].x.tolist() == features

    # Paper id 5 is relabelled to 4; every other surviving id is unchanged.
    expected_edges = {
        '1': [[0, 1, 2], [0, 1, 3]],
        '2': [[1, 3, 4], [0, 1, 2]],
    }
    for relation, edges in expected_edges.items():
        assert result[relation].edge_index.tolist() == edges
    assert result['2'].edge_attr.tolist() == [0, 1, 2]
    assert result['3'].edge_index.tolist() == [[4], [2]]
59 changes: 47 additions & 12 deletions torch_geometric/transforms/remove_isolated_nodes.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,63 @@
import re
from collections import defaultdict
from typing import Union

import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_isolated_nodes


@functional_transform('remove_isolated_nodes')
class RemoveIsolatedNodes(BaseTransform):
    r"""Removes isolated nodes from the graph
    (functional name: :obj:`remove_isolated_nodes`).

    A node is kept when its index occurs as source or destination in the
    :obj:`edge_index` of at least one edge store; all other nodes are
    removed. Works on both :class:`Data` and :class:`HeteroData` objects.
    The input object is modified in place and returned.
    """
    @staticmethod
    def _endpoint_types(store):
        r"""Returns the :obj:`(src, dst)` node types of an edge store, or
        :obj:`(None, None)` when the store carries no key."""
        if store._key is None:
            return None, None
        src, _, dst = store._key
        return src, dst

    def __call__(self, data: Union[Data, HeteroData]):
        # Gather all nodes that occur in at least one edge (across all types):
        n_id_dict = defaultdict(list)
        for store in data.edge_stores:
            if 'edge_index' not in store:
                continue

            src, dst = self._endpoint_types(store)
            n_id_dict[src].append(store.edge_index[0])
            n_id_dict[dst].append(store.edge_index[1])

        # One de-duplicated id tensor per node type.
        n_id_dict = {k: torch.cat(v).unique() for k, v in n_id_dict.items()}

        # Build, per node type, a lookup table from old node id to new
        # (compacted) node id. Entries of removed nodes stay zero and are
        # never read, since only ids that occur in an edge are looked up.
        n_map_dict = {}
        for store in data.node_stores:
            if store._key not in n_id_dict:
                # Node type without any incident edge: nothing is kept.
                n_id_dict[store._key] = torch.empty((0, ), dtype=torch.long)

            idx = n_id_dict[store._key]
            mapping = idx.new_zeros(data.num_nodes)
            mapping[idx] = torch.arange(idx.numel(), device=mapping.device)
            n_map_dict[store._key] = mapping

        # Relabel the endpoints of every edge to the compacted ids.
        for store in data.edge_stores:
            if 'edge_index' not in store:
                continue

            src, dst = self._endpoint_types(store)
            row = n_map_dict[src][store.edge_index[0]]
            col = n_map_dict[dst][store.edge_index[1]]
            store.edge_index = torch.stack([row, col], dim=0)

        # Drop the entries of removed nodes from all node-level attributes.
        for store in data.node_stores:
            kept = n_id_dict[store._key]

            # Classify every attribute before reassigning any of them, so the
            # store is not mutated while it is being iterated.
            # NOTE(review): `is_node_attr` presumably consults the store's
            # current node count, which the assignments below change -
            # confirm against the storage implementation.
            items = list(store.items())
            node_attrs = [(key, value) for key, value in items
                          if key != 'num_nodes' and store.is_node_attr(key)]

            for key, value in node_attrs:
                store[key] = value[kept]

            # Only overwrite an explicitly stored node count; never add one.
            if any(key == 'num_nodes' for key, _ in items):
                store.num_nodes = kept.numel()

        return data