Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RemoveSelfLoops transformation #9562

Merged
merged 12 commits into from
Aug 14, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `RemoveSelfLoops` transformation ([#9562](https://github.com/pyg-team/pytorch_geometric/pull/9562))
- Added ONNX export for `scatter` with min/max reductions ([#9587](https://github.com/pyg-team/pytorch_geometric/pull/9587))
- Added a `residual` option in `GATConv` and `GATv2Conv` ([#9515](https://github.com/pyg-team/pytorch_geometric/pull/9515))
- Added the `PatchTransformerAggregation` layer ([#9487](https://github.com/pyg-team/pytorch_geometric/pull/9487))
Expand Down
45 changes: 45 additions & 0 deletions test/transforms/test_remove_self_loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import RemoveSelfLoops


def test_remove_self_loops():
assert str(RemoveSelfLoops()) == 'RemoveSelfLoops()'

assert len(RemoveSelfLoops()(Data())) == 0

edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]])
edge_weight = torch.tensor([1, 2, 3, 4])
edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])

data = Data(edge_index=edge_index, num_nodes=3)
data = RemoveSelfLoops()(data)
assert len(data) == 2
assert data.edge_index.tolist() == [[1, 2], [0, 1]]
assert data.num_nodes == 3

data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)
data = RemoveSelfLoops(attr='edge_weight')(data)
assert data.edge_index.tolist() == [[1, 2], [0, 1]]
assert data.num_nodes == 3
assert data.edge_weight.tolist() == [2, 4]

data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)
data = RemoveSelfLoops(attr='edge_attr')(data)
assert data.edge_index.tolist() == [[1, 2], [0, 1]]
assert data.num_nodes == 3
assert data.edge_attr.tolist() == [[3, 4], [7, 8]]


def test_hetero_remove_self_loops():
edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]])

data = HeteroData()
data['v'].num_nodes = 3
data['w'].num_nodes = 3
data['v', 'v'].edge_index = edge_index
data['v', 'w'].edge_index = edge_index
data = RemoveSelfLoops()(data)
assert data['v', 'v'].edge_index.tolist() == [[1, 2], [0, 1]]
assert data['v', 'w'].edge_index.tolist() == edge_index.tolist()
2 changes: 2 additions & 0 deletions torch_geometric/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .local_degree_profile import LocalDegreeProfile
from .add_self_loops import AddSelfLoops
from .add_remaining_self_loops import AddRemainingSelfLoops
from .remove_self_loops import RemoveSelfLoops
from .remove_isolated_nodes import RemoveIsolatedNodes
from .remove_duplicated_edges import RemoveDuplicatedEdges
from .knn_graph import KNNGraph
Expand Down Expand Up @@ -87,6 +88,7 @@
'LocalDegreeProfile',
'AddSelfLoops',
'AddRemainingSelfLoops',
'RemoveSelfLoops',
'RemoveIsolatedNodes',
'RemoveDuplicatedEdges',
'KNNGraph',
Expand Down
36 changes: 36 additions & 0 deletions torch_geometric/transforms/remove_self_loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Union

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_self_loops


@functional_transform('remove_self_loops')
class RemoveSelfLoops(BaseTransform):
r"""Removes all self-loops in the given homogeneous or heterogeneous
graph (functional name: :obj:`remove_self_loops`).

Args:
attr (str, optional): The name of the attribute of edge weights
or multi-dimensional edge features to pass to
:meth:`torch_geometric.utils.remove_self_loops`.
(default: :obj:`"edge_weight"`)
"""
def __init__(self, attr: str = 'edge_weight') -> None:
self.attr = attr

def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.edge_stores:
if store.is_bipartite() or 'edge_index' not in store:
continue

store.edge_index, store[self.attr] = remove_self_loops(
store.edge_index,
edge_attr=store.get(self.attr, None),
)

return data
Loading