Add Index.add and Index.sub (#9289)
rusty1s authored May 4, 2024
1 parent 9a2e35e commit ebabce5
Showing 4 changed files with 269 additions and 2 deletions.
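In short, this commit makes `torch_geometric.Index` participate in `add` and `sub` (including the in-place and `alpha`-scaled variants), propagating its metadata (`dim_size`, `is_sorted`) whenever it provably survives the operation. A minimal usage sketch, assuming only the semantics exercised by the tests below:

from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)

out = index + 2                # shift every index by a constant
assert isinstance(out, Index)
assert out.dim_size == 5       # the valid range shifts along: 3 + 2
assert out.is_sorted           # a constant shift preserves order

out = index.sub(index)         # Index - Index: values are valid, but...
assert out.dim_size is None    # ...no meaningful range survives
assert not out.is_sorted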
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288))
+- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
1 change: 0 additions & 1 deletion test/test_edge_index.py
@@ -1145,7 +1145,6 @@ def test_data_loader(dtype, num_workers):
        assert len(batch) == 2
        for adj in batch:
            assert isinstance(adj, EdgeIndex)
-            assert adj.dtype == adj.dtype
            assert adj.is_shared() == (num_workers > 0)
            assert adj._data.is_shared() == (num_workers > 0)
            assert adj._indptr.is_shared() == (num_workers > 0)
131 changes: 131 additions & 0 deletions test/test_index.py
@@ -1,3 +1,6 @@
import os.path as osp
from typing import List

import pytest
import torch
from torch import tensor
@@ -414,3 +417,131 @@ def test_getitem(dtype, device):
    out = tmp[index]
    assert not isinstance(out, Index)
    assert out.equal(tmp[index.as_tensor()])


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_add(dtype, device):
    kwargs = dict(dtype=dtype, device=device)
    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)

    out = torch.add(index, 2, alpha=2)
    assert isinstance(out, Index)
    assert out.equal(tensor([4, 5, 5, 6], device=device))
    assert out.dim_size == 7
    assert out.is_sorted

    out = index + torch.tensor([2], dtype=dtype, device=device)
    assert isinstance(out, Index)
    assert out.equal(tensor([2, 3, 3, 4], device=device))
    assert out.dim_size == 5
    assert out.is_sorted

    out = index.add(index)
    assert isinstance(out, Index)
    assert out.equal(tensor([0, 2, 2, 4], device=device))
    assert out.dim_size == 6
    assert not out.is_sorted

    index += 2
    assert isinstance(index, Index)
    assert index.equal(tensor([2, 3, 3, 4], device=device))
    assert index.dim_size == 5
    assert index.is_sorted

    with pytest.raises(RuntimeError, match="can't be cast"):
        index += 2.5
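For reference, `torch.add(index, 2, alpha=2)` above follows PyTorch's usual `alpha` semantics, `input + alpha * other`: `[0, 1, 1, 2] + 2 * 2 = [4, 5, 5, 6]`, which is also why `dim_size` grows from 3 to 3 + 2 * 2 = 7.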


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_sub(dtype, device):
    kwargs = dict(dtype=dtype, device=device)
    index = Index([4, 5, 5, 6], dim_size=7, is_sorted=True, **kwargs)

    out = torch.sub(index, 2, alpha=2)
    assert isinstance(out, Index)
    assert out.equal(tensor([0, 1, 1, 2], device=device))
    assert out.dim_size == 3
    assert out.is_sorted

    out = index - torch.tensor([2], dtype=dtype, device=device)
    assert isinstance(out, Index)
    assert out.equal(tensor([2, 3, 3, 4], device=device))
    assert out.dim_size == 5
    assert out.is_sorted

    out = index.sub(index)
    assert isinstance(out, Index)
    assert out.equal(tensor([0, 0, 0, 0], device=device))
    assert out.dim_size is None
    assert not out.is_sorted

    index -= 2
    assert isinstance(index, Index)
    assert index.equal(tensor([2, 3, 3, 4], device=device))
    assert index.dim_size == 5
    assert index.is_sorted

    with pytest.raises(RuntimeError, match="can't be cast"):
        index -= 2.5


def test_to_list():
    index = Index([0, 1, 1, 2])
    with pytest.raises(RuntimeError, match="supported for tensor subclasses"):
        index.tolist()


def test_numpy():
    index = Index([0, 1, 1, 2])
    with pytest.raises(RuntimeError, match="supported for tensor subclasses"):
        index.numpy()
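`tolist()` and `numpy()` are rejected on `Index` itself. If a plain list or array is needed, one workaround would be to detach from the subclass first via `as_tensor()` (already used in `test_getitem` above):

index.as_tensor().tolist()  # convert to a plain Tensor, then to a list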


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_save_and_load(dtype, device, tmp_path):
    kwargs = dict(dtype=dtype, device=device)
    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)
    index.fill_cache_()

    path = osp.join(tmp_path, 'edge_index.pt')
    torch.save(index, path)
    out = torch.load(path)

    assert isinstance(out, Index)
    assert out.equal(index)
    assert out.dim_size == 3
    assert out.is_sorted
    assert out._indptr.equal(index._indptr)
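A caveat when reproducing this test on newer PyTorch releases: starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which rejects tensor subclasses such as `Index` unless they are allow-listed, so the load may need to be written as:

out = torch.load(path, weights_only=False)  # only for files you trust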


def _collate_fn(indices: List[Index]) -> List[Index]:
    return indices


@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('num_workers', [0, 2])
def test_data_loader(dtype, num_workers):
    kwargs = dict(dtype=dtype)
    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)
    index.fill_cache_()

    loader = torch.utils.data.DataLoader(
        [index] * 4,
        batch_size=2,
        num_workers=num_workers,
        collate_fn=_collate_fn,
        drop_last=True,
    )

    assert len(loader) == 2
    for batch in loader:
        assert isinstance(batch, list)
        assert len(batch) == 2
        for index in batch:
            assert isinstance(index, Index)
            assert index.is_shared() == (num_workers > 0)
            assert index._data.is_shared() == (num_workers > 0)
            assert index._indptr.is_shared() == (num_workers > 0)
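For completeness, the new tests can be run in isolation from the repository root with `pytest test/test_index.py -k "test_add or test_sub"`.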
137 changes: 137 additions & 0 deletions torch_geometric/index.py
@@ -359,6 +359,13 @@ def _shallow_copy(self) -> 'Index':
        out._cat_metadata = self._cat_metadata
        return out

    def _clear_metadata(self) -> 'Index':
        # Drop all cached metadata; the in-place handlers below re-derive
        # whatever still provably holds after mutating `_data`.
        self._dim_size = None
        self._is_sorted = False
        self._indptr = None
        self._cat_metadata = None
        return self


def apply_(
    tensor: Index,
@@ -598,3 +605,133 @@ def _index(
        out._dim_size = input.dim_size

    return out


@implements(aten.add.Tensor)
def _add(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:

    data = aten.add.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    # Fall back to a plain tensor if the result can no longer be an index:
    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    # A single-element tensor behaves like a scalar shift:
    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        # Shifting by a constant moves the valid range and preserves order:
        if input.dim_size is not None:
            out._dim_size = input.dim_size + alpha * other
        out._is_sorted = input.is_sorted

    elif isinstance(other, Index):
        # Adding two indices can reach up to the sum of their ranges:
        if input.dim_size is not None and other.dim_size is not None:
            out._dim_size = input.dim_size + alpha * other.dim_size

    return out
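Note the early returns above: if the result leaves the set of integer index dtypes, or is no longer one-dimensional, the handler intentionally degrades to a plain `Tensor`. A small sketch of that escape hatch, following the guard logic (hypothetical values):

import torch
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3)
out = index + torch.tensor([0.5])  # float result is not a valid index dtype
assert not isinstance(out, Index)  # a plain Tensor falls out instead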


@implements(aten.add_.Tensor)
def add_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:

    # Snapshot metadata and clear all cached state before mutating `_data`;
    # only metadata that provably still holds is restored below:
    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.add_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size + alpha * other
        input._is_sorted = is_sorted

    elif isinstance(other, Index):
        if dim_size is not None and other.dim_size is not None:
            input._dim_size = dim_size + alpha * other.dim_size

    return input


@implements(aten.sub.Tensor)
def _sub(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:

    data = aten.sub.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if input.dim_size is not None:
            out._dim_size = input.dim_size - alpha * other
        out._is_sorted = input.is_sorted

    return out
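Unlike `_add`, `_sub` derives no `dim_size` when `other` is an `Index`: the difference of two indices is an offset rather than an index into a known range, which is exactly what `test_sub` asserts via `out.dim_size is None`.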


@implements(aten.sub_.Tensor)
def sub_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:

    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.sub_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size - alpha * other
        input._is_sorted = is_sorted

    return input
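All four handlers register themselves through the `implements` decorator seen above. A minimal sketch of how such a registration hook typically works for tensor subclasses (the actual `implements` helper in `torch_geometric/index.py` may differ in its details):

from typing import Any, Callable, Dict

HANDLED_FUNCTIONS: Dict[Any, Callable] = {}

def implements(aten_op: Any) -> Callable:
    # Map an ATen overload to its Index-aware override; the subclass's
    # `__torch_dispatch__` consults this table before falling back to
    # the default tensor behavior.
    def decorator(fn: Callable) -> Callable:
        HANDLED_FUNCTIONS[aten_op] = fn
        return fn
    return decorator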
