Skip to content

Commit

Permalink
Add is_sparse and to_torch_coo_tensor (pyg-team#6003)
Browse files Browse the repository at this point in the history
This PR aims to
+ Add `is_sparse` to check if the input is either `torch.sparse.Tensor`
or `torch_sparse.SparseTensor`
+ Add `to_torch_coo_tensor` to convert `edge_index` and `edge_weight` to
`torch.sparse.Tensor` (in COO format)
+ Move `is_torch_sparse_tensor` from `torch_sparse_tensor.py` to
`sparse.py`. I think this makes the structure in `torch_geometric.utils`
clearer.
+ Clean up duplicated code in `message_passing.py`.

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
2 people authored and JakubPietrakIntel committed Nov 25, 2022
1 parent b4f61aa commit 2e3edf9
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 52 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Add `to_fixed_size` graph transformer ([#5939](https://github.com/pyg-team/pytorch_geometric/pull/5939))
- Add support for symbolic tracing of `SchNet` model ([#5938](https://github.com/pyg-team/pytorch_geometric/pull/5938))
- Add support for customizable interaction graph in `SchNet` model ([#5919](https://github.com/pyg-team/pytorch_geometric/pull/5919))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003))
- Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903))
- Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886))
- Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))
Expand Down
56 changes: 55 additions & 1 deletion test/utils/test_sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.testing import is_full_test
from torch_geometric.utils import dense_to_sparse
from torch_geometric.utils import (
dense_to_sparse,
is_sparse,
is_torch_sparse_tensor,
to_torch_coo_tensor,
)


def test_dense_to_sparse():
Expand Down Expand Up @@ -35,3 +41,51 @@ def test_dense_to_sparse():
edge_index, edge_attr = jit(adj)
assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]]
assert edge_attr.tolist() == [3, 1, 2, 1, 2]


def test_is_torch_sparse_tensor():
x = torch.randn(5, 5)

assert not is_torch_sparse_tensor(x)
assert not is_torch_sparse_tensor(SparseTensor.from_dense(x))
assert is_torch_sparse_tensor(x.to_sparse())


def test_is_sparse():
x = torch.randn(5, 5)

assert not is_sparse(x)
assert is_sparse(SparseTensor.from_dense(x))
assert is_sparse(x.to_sparse())


def test_to_torch_coo_tensor():
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 3],
[1, 0, 2, 1, 3, 2],
])
edge_attr = torch.randn(edge_index.size(1), 8)

adj = to_torch_coo_tensor(edge_index)
assert adj.size() == (4, 4)
assert adj.layout == torch.sparse_coo
assert torch.allclose(adj.indices(), edge_index)

adj = to_torch_coo_tensor(edge_index, size=6)
assert adj.size() == (6, 6)
assert adj.layout == torch.sparse_coo
assert torch.allclose(adj.indices(), edge_index)

adj = to_torch_coo_tensor(edge_index, edge_attr)
assert adj.size() == (4, 4, 8)
assert adj.layout == torch.sparse_coo
assert torch.allclose(adj.indices(), edge_index)
assert torch.allclose(adj.values(), edge_attr)

if is_full_test():
jit = torch.jit.script(to_torch_coo_tensor)
adj = jit(edge_index, edge_attr)
assert adj.size() == (4, 4, 8)
assert adj.layout == torch.sparse_coo
assert torch.allclose(adj.indices(), edge_index)
assert torch.allclose(adj.values(), edge_attr)
12 changes: 0 additions & 12 deletions test/utils/test_torch_sparse_tensor.py

This file was deleted.

29 changes: 8 additions & 21 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
from torch_geometric.typing import Adj, Size
from torch_geometric.utils import is_torch_sparse_tensor
from torch_geometric.utils import is_sparse, is_torch_sparse_tensor

from .utils.helpers import expand_left
from .utils.inspector import Inspector, func_body_repr, func_header_repr
Expand Down Expand Up @@ -183,14 +183,15 @@ def __init__(
def __check_input__(self, edge_index, size):
the_size: List[Optional[int]] = [None, None]

if is_torch_sparse_tensor(edge_index):
if is_sparse(edge_index):
if self.flow == 'target_to_source':
raise ValueError(
('Flow direction "target_to_source" is invalid for '
'message propagation via `torch.sparse.Tensor`. If '
'you really want to make use of a reverse message '
'passing flow, pass in the transposed sparse tensor to '
'the message passing module, e.g., `adj_t.t()`.'))
'message propagation via `torch_sparse.SparseTensor` '
'or `torch.sparse.Tensor`. If you really want to make '
'use of a reverse message passing flow, pass in the '
'transposed sparse tensor to the message passing module, '
'e.g., `adj_t.t()`.'))
the_size[0] = edge_index.size(1)
the_size[1] = edge_index.size(0)
return the_size
Expand All @@ -212,18 +213,6 @@ def __check_input__(self, edge_index, size):
the_size[1] = size[1]
return the_size

elif isinstance(edge_index, SparseTensor):
if self.flow == 'target_to_source':
raise ValueError(
('Flow direction "target_to_source" is invalid for '
'message propagation via `torch_sparse.SparseTensor`. If '
'you really want to make use of a reverse message '
'passing flow, pass in the transposed sparse tensor to '
'the message passing module, e.g., `adj_t.t()`.'))
the_size[0] = edge_index.sparse_size(1)
the_size[1] = edge_index.sparse_size(0)
return the_size

raise ValueError(
('`MessagePassing.propagate` only supports integer tensors of '
'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
Expand Down Expand Up @@ -403,9 +392,7 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
size = self.__check_input__(edge_index, size)

# Run "fused" message and aggregation (if applicable).
if ((isinstance(edge_index, SparseTensor)
or is_torch_sparse_tensor(edge_index)) and self.fuse
and not self.explain):
if is_sparse(edge_index) and self.fuse and not self.explain:
coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
size, kwargs)

Expand Down
6 changes: 4 additions & 2 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from .mask import index_to_mask, mask_to_index
from .to_dense_batch import to_dense_batch
from .to_dense_adj import to_dense_adj
from .sparse import dense_to_sparse
from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor,
to_torch_coo_tensor)
from .unbatch import unbatch, unbatch_edge_index
from .normalized_cut import normalized_cut
from .grid import grid
Expand All @@ -36,7 +37,6 @@
structured_negative_sampling_feasible)
from .train_test_split_edges import train_test_split_edges
from .scatter import scatter
from .torch_sparse_tensor import is_torch_sparse_tensor
from .spmm import spmm

__all__ = [
Expand Down Expand Up @@ -98,6 +98,8 @@
'train_test_split_edges',
'scatter',
'is_torch_sparse_tensor',
'is_sparse',
'to_torch_coo_tensor',
'spmm',
]

Expand Down
69 changes: 68 additions & 1 deletion torch_geometric/utils/sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Tuple
from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor
from torch_sparse import SparseTensor


def dense_to_sparse(adj: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -46,3 +47,69 @@ def dense_to_sparse(adj: Tensor) -> Tuple[Tensor, Tensor]:
row = batch + edge_index[1]
col = batch + edge_index[2]
return torch.stack([row, col], dim=0), edge_attr


def is_torch_sparse_tensor(src: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`src` is a
:class:`torch.sparse.Tensor` (in any sparse layout).
Args:
src (Any): The input object to be checked.
"""
return isinstance(src, Tensor) and src.is_sparse


def is_sparse(src: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`src` is of type
:class:`torch.sparse.Tensor` (in any sparse layout) or of type
:class:`torch_sparse.SparseTensor`.
Args:
src (Any): The input object to be checked.
"""
return is_torch_sparse_tensor(src) or isinstance(src, SparseTensor)


def to_torch_coo_tensor(
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
size: Optional[Union[int, Tuple[int, int]]] = None,
) -> Tensor:
"""Converts a sparse adjacency matrix defined by edge indices and edge
attributes to a :class:`torch.sparse.Tensor`.
Args:
edge_index (LongTensor): The edge indices.
edge_attr (Tensor, optional): The edge attributes.
(default: :obj:`None`)
size (int or (int, int), optional): The size of the sparse matrix.
If given as an integer, will create a quadratic sparse matrix.
If set to :obj:`None`, will infer a quadratic sparse matrix based
on :obj:`edge_index.max() + 1`. (default: :obj:`None`)
:rtype: :class:`torch.sparse.FloatTensor`
Example:
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
... [1, 0, 2, 1, 3, 2]])
>>> to_torch_coo_tensor(edge_index)
tensor(indices=tensor([[0, 1, 1, 2, 2, 3],
[1, 0, 2, 1, 3, 2]]),
values=tensor([1., 1., 1., 1., 1., 1.]),
size=(4, 4), nnz=6, layout=torch.sparse_coo)
"""
if size is None:
size = int(edge_index.max()) + 1
if not isinstance(size, (tuple, list)):
size = (size, size)

if edge_attr is None:
edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)

size = tuple(size) + edge_attr.size()[1:]
out = torch.sparse_coo_tensor(edge_index, edge_attr, size,
device=edge_index.device)
out = out.coalesce()
return out
2 changes: 1 addition & 1 deletion torch_geometric/utils/spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import Tensor
from torch_sparse import SparseTensor, matmul

from .torch_sparse_tensor import is_torch_sparse_tensor
from .sparse import is_torch_sparse_tensor


@torch.jit._overload
Expand Down
13 changes: 0 additions & 13 deletions torch_geometric/utils/torch_sparse_tensor.py

This file was deleted.

0 comments on commit 2e3edf9

Please sign in to comment.