Skip to content

Commit

Permalink
Add documentation to torch_geometric.Index (#9297)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored May 6, 2024
1 parent 8bb44ed commit 7d0fcd2
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
Expand Down
3 changes: 2 additions & 1 deletion docs/source/modules/root.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ torch_geometric
Tensor Objects
--------------

.. currentmodule:: torch_geometric.edge_index
.. currentmodule:: torch_geometric

.. autosummary::
:nosignatures:
:toctree: ../generated

Index
EdgeIndex

Functions
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':
# Methods #################################################################

def share_memory_(self) -> 'EdgeIndex':
"""""" # noqa: D419
self._data.share_memory_()
if self._indptr is not None:
self._indptr.share_memory_()
Expand All @@ -714,6 +715,7 @@ def share_memory_(self) -> 'EdgeIndex':
return self

def is_shared(self) -> bool:
"""""" # noqa: D419
return self._data.is_shared()

def as_tensor(self) -> Tensor:
Expand Down
72 changes: 64 additions & 8 deletions torch_geometric/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,51 @@ def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:


class Index(Tensor):
r"""TODO."""
r"""A one-dimensional :obj:`index` tensor with additional (meta)data
attached.
:class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
indices of shape :obj:`[num_indices]`.
While :class:`Index` sub-classes a general :pytorch:`null`
:class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:
* :obj:`dim_size`: The size of the underlying sparse vector size, *i.e.*,
the size of a dimension that can be indexed via :obj:`index`.
By default, it is inferred as :obj:`dim_size=index.max() + 1`.
* :obj:`is_sorted`: Whether indices are sorted in ascending order.
Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
conversion in case its representation is sorted.
Caches are filled based on demand (*e.g.*, when calling
:meth:`Index.get_indptr`), or when explicitly requested via
:meth:`Index.fill_cache_`, and are maintaned and adjusted over its
lifespan.
This representation ensures for optimal computation in GNN message passing
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
workflows.
.. code-block:: python
from torch_geometric import Index
index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
>>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
assert index.dim_size == 3
assert index.is_sorted
# Flipping order:
edge_index.flip(0)
>>> Index([[2, 1, 1, 0], dim_size=3)
assert not index.is_sorted
# Filtering:
mask = torch.tensor([True, True, True, False])
index[:, mask]
>>> Index([[0, 1, 1], dim_size=3, is_sorted=True)
assert index.is_sorted
"""
# See "https://pytorch.org/docs/stable/notes/extending.html"
# for a basic tutorial on how to subclass `torch.Tensor`.

Expand Down Expand Up @@ -166,7 +210,13 @@ def __new__(
# Validation ##############################################################

def validate(self) -> 'Index':
r"""TODO."""
r"""Validates the :class:`Index` representation.
In particular, it ensures that
* it only holds valid indices.
* the sort order is correctly set.
"""
assert_valid_dtype(self._data)
assert_one_dimensional(self._data)
assert_contiguous(self._data)
Expand All @@ -191,12 +241,12 @@ def validate(self) -> 'Index':

@property
def dim_size(self) -> Optional[int]:
r"""TODO."""
r"""The size of the underlying sparse vector."""
return self._dim_size

@property
def is_sorted(self) -> bool:
r"""TODO."""
r"""Returns whether indices are sorted in ascending order."""
return self._is_sorted

@property
Expand All @@ -207,7 +257,9 @@ def dtype(self) -> torch.dtype: # type: ignore
# Cache Interface #########################################################

def get_dim_size(self) -> int:
r"""TODO."""
r"""The size of the underlying sparse vector.
Automatically computed and cached when not explicitly set.
"""
if self._dim_size is None:
dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
self._dim_size = dim_size
Expand All @@ -216,7 +268,7 @@ def get_dim_size(self) -> int:
return self._dim_size

def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
r"""TODO."""
r"""Assigns or re-assigns the size of the underlying sparse vector."""
if self.is_sorted and self._indptr is not None:
if dim_size is None:
self._indptr = None
Expand All @@ -237,15 +289,17 @@ def dim_resize_(self, dim_size: Optional[int]) -> 'Index':

@assert_sorted
def get_indptr(self) -> Tensor:
r"""TODO."""
r"""Returns the compressed index representation in case :class:`Index`
is sorted.
"""
if self._indptr is None:
self._indptr = index2ptr(self._data, self.get_dim_size())

assert isinstance(self._indptr, Tensor)
return self._indptr

def fill_cache_(self) -> 'Index':
r"""TODO."""
r"""Fills the cache with (meta)data information."""
self.get_dim_size()

if self.is_sorted:
Expand All @@ -256,12 +310,14 @@ def fill_cache_(self) -> 'Index':
# Methods #################################################################

def share_memory_(self) -> 'Index':
"""""" # noqa: D419
self._data.share_memory_()
if self._indptr is not None:
self._indptr.share_memory_()
return self

def is_shared(self) -> bool:
"""""" # noqa: D419
return self._data.is_shared()

def as_tensor(self) -> Tensor:
Expand Down

0 comments on commit 7d0fcd2

Please sign in to comment.