Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add is_sparse and to_torch_coo_tensor #6003

Merged
merged 9 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion test/utils/test_sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.testing import is_full_test
from torch_geometric.utils import dense_to_sparse
from torch_geometric.utils import (
dense_to_sparse,
is_sparse,
is_torch_sparse_tensor,
to_torch_coo_tensor,
)


def test_dense_to_sparse():
Expand Down Expand Up @@ -35,3 +41,24 @@ def test_dense_to_sparse():
edge_index, edge_attr = jit(adj)
assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]]
assert edge_attr.tolist() == [3, 1, 2, 1, 2]


def test_is_torch_sparse_tensor():
x = torch.randn(5, 5)

assert not is_torch_sparse_tensor(x)
assert not is_torch_sparse_tensor(SparseTensor.from_dense(x))
assert is_torch_sparse_tensor(x.to_sparse())


def test_is_sparse():
x = torch.randn(5, 5)

assert not is_sparse(x)
assert is_sparse(SparseTensor.from_dense(x))
assert is_sparse(x.to_sparse())


def to_torch_coo_tensor():
# TODO
pass
12 changes: 0 additions & 12 deletions test/utils/test_torch_sparse_tensor.py

This file was deleted.

6 changes: 4 additions & 2 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from .mask import index_to_mask, mask_to_index
from .to_dense_batch import to_dense_batch
from .to_dense_adj import to_dense_adj
from .sparse import dense_to_sparse
from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor,
to_torch_coo_tensor)
from .unbatch import unbatch, unbatch_edge_index
from .normalized_cut import normalized_cut
from .grid import grid
Expand All @@ -36,7 +37,6 @@
structured_negative_sampling_feasible)
from .train_test_split_edges import train_test_split_edges
from .scatter import scatter
from .torch_sparse_tensor import is_torch_sparse_tensor
from .spmm import spmm

__all__ = [
Expand Down Expand Up @@ -98,6 +98,8 @@
'train_test_split_edges',
'scatter',
'is_torch_sparse_tensor',
'is_sparse',
'to_torch_coo_tensor',
'spmm',
]

Expand Down
65 changes: 64 additions & 1 deletion torch_geometric/utils/sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Tuple
from typing import Any, Optional, Tuple

import torch
from torch import Tensor
from torch_sparse import SparseTensor

from .num_nodes import maybe_num_nodes
rusty1s marked this conversation as resolved.
Show resolved Hide resolved


def dense_to_sparse(adj: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -46,3 +49,63 @@ def dense_to_sparse(adj: Tensor) -> Tuple[Tensor, Tensor]:
row = batch + edge_index[1]
col = batch + edge_index[2]
return torch.stack([row, col], dim=0), edge_attr


def is_torch_sparse_tensor(src: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`src` is a PyTorch
:obj:`SparseTensor` (in any sparse format).

Args:
src (Any): The input object to be checked.
"""
return isinstance(src, Tensor) and src.is_sparse


def is_sparse(src: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`src` is a PyTorch
:obj:`SparseTensor` (in any sparse format) or a
:obj:`torch_sparse.SparseTensor`.

Args:
src (Any): The input object to be checked.
"""
return is_torch_sparse_tensor(src) or isinstance(src, SparseTensor)


def to_torch_coo_tensor(
edge_index: Tensor,
edge_weight: Optional[Tensor] = None,
num_nodes: Optional[int] = None,
) -> Tensor:
"""Converts edge index to sparse adjacency matrix
:class:`torch.sparse.Tensor`.

Args:
edge_index (LongTensor): The edge indices.
edge_attr (Tensor, optional): The edge weights.
(default: :obj:`None`)
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

:rtype: :class:`torch.sparse.FloatTensor`

Example:

>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
... [1, 0, 2, 1, 3, 2]])
>>> to_torch_coo_tensor(edge_index)
tensor(indices=tensor([[0, 1, 1, 2, 2, 3],
[1, 0, 2, 1, 3, 2]]),
values=tensor([1., 1., 1., 1., 1., 1.]),
size=(4, 4), nnz=6, layout=torch.sparse_coo)

"""
num_nodes = maybe_num_nodes(edge_index, num_nodes)
device = edge_index.device
if edge_weight is None:
edge_weight = torch.ones(edge_index.size(1), device=device)

shape = torch.Size((num_nodes, num_nodes))
adj = torch.sparse_coo_tensor(edge_index, edge_weight, shape,
device=device)
return adj.coalesce()
2 changes: 1 addition & 1 deletion torch_geometric/utils/spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import Tensor
from torch_sparse import SparseTensor, matmul

from .torch_sparse_tensor import is_torch_sparse_tensor
from .sparse import is_torch_sparse_tensor


@torch.jit._overload
Expand Down
13 changes: 0 additions & 13 deletions torch_geometric/utils/torch_sparse_tensor.py

This file was deleted.