diff --git a/CHANGELOG.md b/CHANGELOG.md index 49e88c7d2318..28f5f9e79aed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for `EdgeIndex.unbind` ([#9298](https://github.com/pyg-team/pytorch_geometric/pull/9298)) - Integrate `torch_geometric.Index` into `torch_geometric.EdgeIndex` ([#9296](https://github.com/pyg-team/pytorch_geometric/pull/9296)) - Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291)) - Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297)) diff --git a/test/test_edge_index.py b/test/test_edge_index.py index e2a8314d4b93..399e14b0a89f 100644 --- a/test/test_edge_index.py +++ b/test/test_edge_index.py @@ -615,6 +615,33 @@ def test_select(dtype, device): assert out._indptr is None +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_unbind(dtype, device): + kwargs = dict(dtype=dtype, device=device) + + adj = EdgeIndex( + [[0, 1, 1, 2], [1, 0, 2, 1]], + sort_order='row', + sparse_size=(4, 5), + **kwargs, + ).fill_cache_() + + row, col = adj + + assert isinstance(row, Index) + assert row.equal(tensor([0, 1, 1, 2], device=device)) + assert row.dim_size == 4 + assert row.is_sorted + assert row._indptr.equal(tensor([0, 1, 3, 4, 4], device=device)) + + assert isinstance(col, Index) + assert col.equal(tensor([1, 0, 2, 1], device=device)) + assert col.dim_size == 5 + assert not col.is_sorted + assert col._indptr is None + + @withCUDA @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('value_dtype', [None, torch.double]) diff --git a/torch_geometric/edge_index.py b/torch_geometric/edge_index.py index da1a1b30ab25..6b4ccf860092 100644 --- a/torch_geometric/edge_index.py +++ b/torch_geometric/edge_index.py @@ -1542,6 +1542,22 @@ def _select(input: EdgeIndex, dim: int, index: int) -> Union[Tensor, Index]: return out +@implements(aten.unbind.int) +def _unbind( + input: EdgeIndex, + dim: int = 0, +) -> Union[List[Index], List[Tensor]]: + + if dim == 0 or dim == -2: + row = input[0] + assert isinstance(row, Index) + col = input[1] + assert isinstance(col, Index) + return [row, col] + + return aten.unbind.int(input._data, dim) + + @implements(aten.add.Tensor) def _add( input: EdgeIndex,