Add documentation to torch_geometric.Index #9297

Merged (2 commits) on May 6, 2024

CHANGELOG.md (1 addition & 1 deletion)
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))

docs/source/modules/root.rst (2 additions & 1 deletion)
@@ -4,12 +4,13 @@ torch_geometric
Tensor Objects
--------------

.. currentmodule:: torch_geometric.edge_index
.. currentmodule:: torch_geometric

.. autosummary::
:nosignatures:
:toctree: ../generated

Index
EdgeIndex

Functions

torch_geometric/edge_index.py (2 additions)
@@ -698,6 +698,7 @@ def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':
# Methods #################################################################

def share_memory_(self) -> 'EdgeIndex':
"""""" # noqa: D419
self._data.share_memory_()
if self._indptr is not None:
self._indptr.share_memory_()
@@ -714,6 +715,7 @@ def share_memory_(self) -> 'EdgeIndex':
return self

def is_shared(self) -> bool:
"""""" # noqa: D419
return self._data.is_shared()

def as_tensor(self) -> Tensor:

torch_geometric/index.py (64 additions & 8 deletions)
@@ -85,7 +85,51 @@ def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:


class Index(Tensor):
r"""TODO."""
r"""A one-dimensional :obj:`index` tensor with additional (meta)data
attached.

:class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
indices of shape :obj:`[num_indices]`.

While :class:`Index` sub-classes a general :pytorch:`null`
:class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

* :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
the size of a dimension that can be indexed via :obj:`index`.
By default, it is inferred as :obj:`dim_size=index.max() + 1`.
* :obj:`is_sorted`: Whether indices are sorted in ascending order.

Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
conversion when its representation is sorted.
Caches are filled on demand (*e.g.*, when calling
:meth:`Index.get_indptr`), or when explicitly requested via
:meth:`Index.fill_cache_`, and are maintained and adjusted over its
lifespan.

This representation ensures optimal computation in GNN message passing
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
workflows.

.. code-block:: python

import torch
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
>>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
assert index.dim_size == 3
assert index.is_sorted

# Flipping order:
flipped = index.flip(0)
>>> Index([2, 1, 1, 0], dim_size=3)
assert not flipped.is_sorted

# Filtering:
mask = torch.tensor([True, True, True, False])
filtered = index[mask]
>>> Index([0, 1, 1], dim_size=3, is_sorted=True)
assert filtered.is_sorted
"""
# See "https://pytorch.org/docs/stable/notes/extending.html"
# for a basic tutorial on how to subclass `torch.Tensor`.

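A quick sketch of the lazy `dim_size` inference described above (`index.max() + 1`); this assumes inference happens on first access via `get_dim_size()`, as the implementation further down in this diff suggests:

```python
from torch_geometric import Index

# No dim_size given at construction time:
index = Index([0, 1, 1, 2])

assert index.dim_size is None      # nothing computed or cached yet
assert index.get_dim_size() == 3   # inferred as index.max() + 1, then cached
assert index.dim_size == 3         # the cached value is now exposed
```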
@@ -166,7 +210,13 @@ def __new__(
# Validation ##############################################################

def validate(self) -> 'Index':
r"""TODO."""
r"""Validates the :class:`Index` representation.

In particular, it ensures that

* it only holds valid indices.
* the sort order is correctly set.
"""
assert_valid_dtype(self._data)
assert_one_dimensional(self._data)
assert_contiguous(self._data)
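As a hedged illustration of what validation catches, assuming that construction itself does not validate and that `validate()` raises a `ValueError` on failure (the exception type is not visible in this diff):

```python
from torch_geometric import Index

# Index 3 is out of range for dim_size=3 (valid indices are 0, 1, 2):
index = Index([0, 1, 3], dim_size=3)

try:
    index.validate()
except ValueError as e:
    print(e)  # reports the out-of-range index
```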
@@ -191,12 +241,12 @@ def validate(self) -> 'Index':

@property
def dim_size(self) -> Optional[int]:
r"""TODO."""
r"""The size of the underlying sparse vector."""
return self._dim_size

@property
def is_sorted(self) -> bool:
r"""TODO."""
r"""Returns whether indices are sorted in ascending order."""
return self._is_sorted

@property
@@ -207,7 +257,9 @@ def dtype(self) -> torch.dtype: # type: ignore
# Cache Interface #########################################################

def get_dim_size(self) -> int:
r"""TODO."""
r"""The size of the underlying sparse vector.
Automatically computed and cached when not explicitly set.
"""
if self._dim_size is None:
dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
self._dim_size = dim_size
@@ -216,7 +268,7 @@ def get_dim_size(self) -> int:
return self._dim_size

def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
r"""TODO."""
r"""Assigns or re-assigns the size of the underlying sparse vector."""
if self.is_sorted and self._indptr is not None:
if dim_size is None:
self._indptr = None
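A small usage sketch for `dim_resize_`; the `None` branch is assumed to clear the explicit size again, in line with the cache handling shown above:

```python
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)

index.dim_resize_(5)           # grow the indexable dimension in place
assert index.dim_size == 5

index.dim_resize_(None)        # drop the explicitly assigned size
assert index.dim_size is None
```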
@@ -237,15 +289,17 @@ def dim_resize_(self, dim_size: Optional[int]) -> 'Index':

@assert_sorted
def get_indptr(self) -> Tensor:
r"""TODO."""
r"""Returns the compressed index representation in case :class:`Index`
is sorted.
"""
if self._indptr is None:
self._indptr = index2ptr(self._data, self.get_dim_size())

assert isinstance(self._indptr, Tensor)
return self._indptr

def fill_cache_(self) -> 'Index':
r"""TODO."""
r"""Fills the cache with (meta)data information."""
self.get_dim_size()

if self.is_sorted:
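To make the CSR conversion concrete: for a sorted index, `indptr[i]` marks the position at which value `i` starts, i.e. the cumulative per-value counts with a leading zero. A minimal sketch, assuming the API as documented above:

```python
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)

# Value 0 spans positions [0, 1), value 1 spans [1, 3), value 2 spans [3, 4):
assert index.get_indptr().tolist() == [0, 1, 3, 4]

index.fill_cache_()  # eagerly populates dim_size (and indptr when sorted)
```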
@@ -256,12 +310,14 @@ def fill_cache_(self) -> 'Index':
# Methods #################################################################

def share_memory_(self) -> 'Index':
"""""" # noqa: D419
self._data.share_memory_()
if self._indptr is not None:
self._indptr.share_memory_()
return self

def is_shared(self) -> bool:
"""""" # noqa: D419
return self._data.is_shared()

def as_tensor(self) -> Tensor:
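Finally, a brief sketch of the shared-memory helpers, which delegate to `torch.Tensor.share_memory_()`/`is_shared()` on the wrapped data (and on the cached `indptr`, if present); `as_tensor()` is assumed to return the plain underlying tensor without the attached metadata:

```python
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)

index.share_memory_()      # move the underlying storage to shared memory
assert index.is_shared()

plain = index.as_tensor()  # plain torch.Tensor without the attached metadata
```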