From 7d0fcd2a6206ecabfa093bc260dbff042888ec6f Mon Sep 17 00:00:00 2001
From: Matthias Fey
Date: Mon, 6 May 2024 10:12:18 +0200
Subject: [PATCH] Add documentation to `torch_geometric.Index` (#9297)

---
 CHANGELOG.md                  |  2 +-
 docs/source/modules/root.rst  |  3 +-
 torch_geometric/edge_index.py |  2 +
 torch_geometric/index.py      | 72 +++++++++++++++++++++++++++++++----
 4 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31baea56f9b4..96127ec18c02 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291))
-- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289))
+- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297))
 - Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
 - Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
 - Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
diff --git a/docs/source/modules/root.rst b/docs/source/modules/root.rst
index 079970f09c62..1ed1a2c4ea41 100644
--- a/docs/source/modules/root.rst
+++ b/docs/source/modules/root.rst
@@ -4,12 +4,13 @@ torch_geometric
 Tensor Objects
 --------------
 
-.. currentmodule:: torch_geometric.edge_index
+.. currentmodule:: torch_geometric
 
 .. autosummary::
    :nosignatures:
    :toctree: ../generated
 
+   Index
    EdgeIndex
 
 Functions
diff --git a/torch_geometric/edge_index.py b/torch_geometric/edge_index.py
index 2e750893476d..41642b11431c 100644
--- a/torch_geometric/edge_index.py
+++ b/torch_geometric/edge_index.py
@@ -698,6 +698,7 @@ def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':
     # Methods #################################################################
 
     def share_memory_(self) -> 'EdgeIndex':
+        """"""  # noqa: D419
         self._data.share_memory_()
         if self._indptr is not None:
             self._indptr.share_memory_()
@@ -714,6 +715,7 @@ def share_memory_(self) -> 'EdgeIndex':
         return self
 
     def is_shared(self) -> bool:
+        """"""  # noqa: D419
         return self._data.is_shared()
 
     def as_tensor(self) -> Tensor:
diff --git a/torch_geometric/index.py b/torch_geometric/index.py
index 21fe96fec922..ce13a013b469 100644
--- a/torch_geometric/index.py
+++ b/torch_geometric/index.py
@@ -85,7 +85,51 @@ def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
 
 
 class Index(Tensor):
-    r"""TODO."""
+    r"""A one-dimensional :obj:`index` tensor with additional (meta)data
+    attached.
+
+    :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
+    indices of shape :obj:`[num_indices]`.
+
+    While :class:`Index` sub-classes a general :pytorch:`null`
+    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:
+
+    * :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
+      the size of a dimension that can be indexed via :obj:`index`.
+      By default, it is inferred as :obj:`dim_size=index.max() + 1`.
+    * :obj:`is_sorted`: Whether indices are sorted in ascending order.
+
+    Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
+    conversion in case its representation is sorted.
+    Caches are filled based on demand (*e.g.*, when calling
+    :meth:`Index.get_indptr`), or when explicitly requested via
+    :meth:`Index.fill_cache_`, and are maintained and adjusted over its
+    lifespan.
+
+    This representation allows for optimal computation in GNN message passing
+    schemes, while preserving the ease of use of regular COO-based :pyg:`PyG`
+    workflows.
+
+    .. code-block:: python
+
+        from torch_geometric import Index
+
+        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
+        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
+        assert index.dim_size == 3
+        assert index.is_sorted
+
+        # Flipping order:
+        flipped = index.flip(0)
+        >>> Index([2, 1, 1, 0], dim_size=3)
+        assert not flipped.is_sorted
+
+        # Filtering:
+        mask = torch.tensor([True, True, True, False])
+        out = index[mask]
+        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
+        assert out.is_sorted
+    """
 
     # See "https://pytorch.org/docs/stable/notes/extending.html"
     # for a basic tutorial on how to subclass `torch.Tensor`.
@@ -166,7 +210,13 @@ def __new__(
     # Validation ##############################################################
 
     def validate(self) -> 'Index':
-        r"""TODO."""
+        r"""Validates the :class:`Index` representation.
+
+        In particular, it ensures that
+
+        * it only holds valid indices.
+        * the sort order is correctly set.
+ """ assert_valid_dtype(self._data) assert_one_dimensional(self._data) assert_contiguous(self._data) @@ -191,12 +241,12 @@ def validate(self) -> 'Index': @property def dim_size(self) -> Optional[int]: - r"""TODO.""" + r"""The size of the underlying sparse vector.""" return self._dim_size @property def is_sorted(self) -> bool: - r"""TODO.""" + r"""Returns whether indices are sorted in ascending order.""" return self._is_sorted @property @@ -207,7 +257,9 @@ def dtype(self) -> torch.dtype: # type: ignore # Cache Interface ######################################################### def get_dim_size(self) -> int: - r"""TODO.""" + r"""The size of the underlying sparse vector. + Automatically computed and cached when not explicitly set. + """ if self._dim_size is None: dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0 self._dim_size = dim_size @@ -216,7 +268,7 @@ def get_dim_size(self) -> int: return self._dim_size def dim_resize_(self, dim_size: Optional[int]) -> 'Index': - r"""TODO.""" + r"""Assigns or re-assigns the size of the underlying sparse vector.""" if self.is_sorted and self._indptr is not None: if dim_size is None: self._indptr = None @@ -237,7 +289,9 @@ def dim_resize_(self, dim_size: Optional[int]) -> 'Index': @assert_sorted def get_indptr(self) -> Tensor: - r"""TODO.""" + r"""Returns the compressed index representation in case :class:`Index` + is sorted. + """ if self._indptr is None: self._indptr = index2ptr(self._data, self.get_dim_size()) @@ -245,7 +299,7 @@ def get_indptr(self) -> Tensor: return self._indptr def fill_cache_(self) -> 'Index': - r"""TODO.""" + r"""Fills the cache with (meta)data information.""" self.get_dim_size() if self.is_sorted: @@ -256,12 +310,14 @@ def fill_cache_(self) -> 'Index': # Methods ################################################################# def share_memory_(self) -> 'Index': + """""" # noqa: D419 self._data.share_memory_() if self._indptr is not None: self._indptr.share_memory_() return self def is_shared(self) -> bool: + """""" # noqa: D419 return self._data.is_shared() def as_tensor(self) -> Tensor: