diff --git a/CHANGELOG.md b/CHANGELOG.md index 38edaa119be5..901bb4c99644 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277)) +- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278)) - Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240)) - Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131)) - Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090)) diff --git a/test/test_index.py b/test/test_index.py index c9928a4a76a7..8b2b34cbb46b 100644 --- a/test/test_index.py +++ b/test/test_index.py @@ -1,5 +1,6 @@ import pytest import torch +from torch import tensor from torch_geometric import Index from torch_geometric.testing import withCUDA @@ -54,3 +55,36 @@ def test_identity(dtype, device): out = Index(index, dim_size=4, is_sorted=False) assert out.dim_size == 4 assert out.is_sorted == index.is_sorted + + +def test_validate(): + with pytest.raises(ValueError, match="unsupported data type"): + Index([0.0, 1.0]) + with pytest.raises(ValueError, match="needs to be one-dimensional"): + Index([[0], [1]]) + with pytest.raises(TypeError, match="invalid combination of arguments"): + Index(torch.tensor([0, 1]), torch.long) + with pytest.raises(TypeError, match="invalid keyword arguments"): + Index(torch.tensor([0, 1]), dtype=torch.long) + with pytest.raises(ValueError, match="contains negative indices"): + Index([-1, 0]).validate() + with pytest.raises(ValueError, match="than its registered size"): + Index([0, 10], dim_size=2).validate() + with pytest.raises(ValueError, match="not sorted"): + Index([1, 0], is_sorted=True).validate() + + +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_fill_cache_(dtype, device): + kwargs = dict(dtype=dtype, device=device) + index = Index([0, 1, 1, 2], is_sorted=True, **kwargs) + index.validate().fill_cache_() + assert index.dim_size == 3 + assert index._indptr.dtype == dtype + assert index._indptr.equal(tensor([0, 1, 3, 4], device=device)) + + index = Index([1, 0, 2, 1], **kwargs) + index.validate().fill_cache_() + assert index.dim_size == 3 + assert index._indptr is None diff --git a/torch_geometric/index.py b/torch_geometric/index.py index 2382ec2ee33c..328257b7d4aa 100644 --- a/torch_geometric/index.py +++ b/torch_geometric/index.py @@ -167,6 +167,24 @@ def __new__( def validate(self) -> 'Index': r"""TODO.""" + assert_valid_dtype(self._data) + assert_one_dimensional(self._data) + assert_contiguous(self._data) + + if self.numel() > 0 and self._data.min() < 0: + raise ValueError(f"'{self.__class__.__name__}' contains negative " + f"indices (got {int(self.min())})") + + if (self.numel() > 0 and self.dim_size is not None + and self._data.max() >= self.dim_size): + raise ValueError(f"'{self.__class__.__name__}' contains larger " + f"indices than its registered size " + f"(got {int(self._data.max())}, but expected " + f"values smaller than {self.dim_size})") + + if self.is_sorted and (self._data.diff() < 0).any(): + raise ValueError(f"'{self.__class__.__name__}' is not sorted") + return self # Properties ############################################################## @@ -185,20 +203,34 @@ def is_sorted(self) -> bool: def get_dim_size(self) -> int: r"""TODO.""" - raise NotImplementedError + if self._dim_size is None: + dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0 + self._dim_size = dim_size + + assert isinstance(self._dim_size, int) + return self._dim_size def dim_resize_(self, dim_size: Optional[int]) -> 'Index': r"""TODO.""" - raise NotImplementedError + raise NotImplementedError # TODO @assert_sorted def get_indptr(self) -> Tensor: r"""TODO.""" - raise NotImplementedError + if self._indptr is None: + self._indptr = index2ptr(self._data, self.get_dim_size()) + + assert isinstance(self._indptr, Tensor) + return self._indptr def fill_cache_(self) -> 'Index': r"""TODO.""" - raise NotImplementedError + self.get_dim_size() + + if self.is_sorted: + self.get_indptr() + + return self # Methods ################################################################# @@ -219,7 +251,7 @@ def as_tensor(self) -> Tensor: def dim_narrow(self, start: Union[int, Tensor], length: int) -> 'Index': r"""TODO.""" - raise NotImplementedError + raise NotImplementedError # TODO # PyTorch/Python builtins #################################################