Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HeteroData support to RemoveIsolatedNodes #4479

Merged
merged 3 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion test/transforms/test_remove_isolated_nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from torch_geometric.data import Data
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import RemoveIsolatedNodes


Expand All @@ -16,3 +16,30 @@ def test_remove_isolated_nodes():
assert data.edge_index.tolist() == [[0, 1, 0], [1, 0, 0]]
assert data.edge_attr.tolist() == [1, 2, 4]
assert data.x.tolist() == [[1], [3]]

data = HeteroData()
data['paper'].x = torch.arange(6).type(torch.float)
data['author'].x = torch.arange(6).type(torch.float)

# paper 0 has only a self-loop (isolated), paper 1 has a self-loop but is connected below, and the third edge (2 -> 3) is not a self-loop
data['paper', 'paper'].edge_index = torch.Tensor([[0, 1, 2],
[0, 1,
3]]).type(torch.long)

# connect paper 1 (which has a self-loop) to an author so it is not isolated, and link papers 3 and 5 as well
data['paper', 'author'].edge_index = torch.Tensor([[1, 3, 5],
[0, 1,
2]]).type(torch.long)

# only edge_attr on one type
data['paper', 'author'].edge_attr = torch.Tensor([1, 2,
3]).type(torch.long)

# add a duplicate (in node index) edge
data['paper', 'cites',
'author'].edge_index = torch.Tensor([[5], [2]]).type(torch.long)

data = RemoveIsolatedNodes()(data)
assert data['paper'].num_nodes == 4
assert data['author'].num_nodes == 3
assert torch.allclose(data['paper'].x, torch.Tensor([1, 2, 3, 5]))
95 changes: 85 additions & 10 deletions torch_geometric/transforms/remove_isolated_nodes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
from typing import Union

import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_isolated_nodes
Expand All @@ -11,18 +13,91 @@
class RemoveIsolatedNodes(BaseTransform):
    r"""Removes isolated nodes from the graph
    (functional name: :obj:`remove_isolated_nodes`).

    Both homogeneous :class:`Data` and heterogeneous :class:`HeteroData`
    objects are accepted; the heterogeneous case is delegated to
    :func:`remove_hetero_isolated_nodes`.
    """
    def __call__(self, data: Union[Data,
                                   HeteroData]) -> Union[Data, HeteroData]:
        """Drops isolated nodes from :obj:`data` and returns it.

        Args:
            data (Data or HeteroData): The graph to filter. It is modified
                in place.

        Raises:
            TypeError: If :obj:`data` is neither a :class:`Data` nor a
                :class:`HeteroData` object.
        """
        if isinstance(data, Data):
            num_nodes = data.num_nodes
            out = remove_isolated_nodes(data.edge_index, data.edge_attr,
                                        num_nodes)
            data.edge_index, data.edge_attr, mask = out

            # Keep an explicitly stored node count in sync with the mask.
            if hasattr(data, '__num_nodes__'):
                data.num_nodes = int(mask.sum())

            for key, item in data:
                # Edge-level attributes were already handled above.
                if 'edge' in key:
                    continue
                # Only tensors whose leading dimension indexes nodes are
                # filtered; scalars (0-dim tensors) are left untouched.
                if (torch.is_tensor(item) and item.dim() > 0
                        and item.size(0) == num_nodes):
                    data[key] = item[mask]

            return data

        if isinstance(data, HeteroData):
            return remove_hetero_isolated_nodes(data)

        raise TypeError(f"`RemoveIsolatedNodes` invalid type: {type(data)}")


def remove_hetero_isolated_nodes(data: HeteroData) -> HeteroData:
    r"""Removes the isolated nodes from the heterogeneous graph
    given by :attr:`data`.
    Self-loops are preserved only for non-isolated nodes. Nodes with only
    a self-loop are removed.

    A node is kept if it is an endpoint of at least one edge (of any edge
    type) that is not a self-loop. Node-level tensors are filtered
    accordingly and the :obj:`edge_index` of every edge type touching the
    node type is relabeled to the new, compacted node indices.

    Args:
        data (HeteroData): The graph to remove nodes from. It is modified
            in place.

    :rtype: HeteroData
    """
    # Use the device of the first edge type that actually stores an
    # `edge_index`; fall back to the default device if there is none.
    device = next((data[edge_type].edge_index.device
                   for edge_type in data.edge_types
                   if 'edge_index' in data[edge_type]), None)

    for node_type in data.node_types:
        node_store = data[node_type]
        # Cache the node count: it may be inferred from node attributes,
        # which are about to be filtered.
        num_nodes = node_store.num_nodes
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)

        # Mark every node of this type that takes part in a non-loop edge.
        for edge_type in data.edge_types:
            if 'edge_index' not in data[edge_type]:
                continue
            row, col = data[edge_type].edge_index
            src, dst = edge_type[0], edge_type[-1]
            if src == dst == node_type:
                # Self-loops alone do not make a node non-isolated, but
                # both endpoints of any other edge are connected.
                non_loop = row != col
                mask[row[non_loop]] = True
                mask[col[non_loop]] = True
            elif src == node_type:
                mask[row] = True
            elif dst == node_type:
                mask[col] = True

        # Map old node indices to new ones; removed nodes map to -1.
        num_kept = int(mask.sum())
        assoc = torch.full((num_nodes, ), -1, dtype=torch.long,
                           device=mask.device)
        assoc[mask] = torch.arange(num_kept, device=mask.device)

        # Filter every node-level tensor of this node type.
        for key, value in list(node_store.items()):
            if (torch.is_tensor(value) and value.dim() > 0
                    and value.size(0) == num_nodes):
                node_store[key] = value[mask]
        # Keep an explicitly stored node count consistent.
        if 'num_nodes' in node_store:
            node_store['num_nodes'] = num_kept

        # Relabel (and, for same-type relations, prune) the edges.
        for edge_type in data.edge_types:
            edge_store = data[edge_type]
            if 'edge_index' not in edge_store:
                continue
            src, dst = edge_type[0], edge_type[-1]
            if node_type not in (src, dst):
                continue

            row, col = edge_store.edge_index
            if src == node_type:
                row = assoc[row]
            if dst == node_type:
                col = assoc[col]
            # Build a new tensor rather than writing into the caller's one.
            edge_index = torch.stack([row, col], dim=0)

            if src == dst:
                # Only self-loops on removed nodes can point to -1 here.
                keep = (row != -1) & (col != -1)
                edge_index = edge_index[:, keep]
                if 'edge_attr' in edge_store:
                    edge_store.edge_attr = edge_store.edge_attr[keep]

            edge_store.edge_index = edge_index

    return data