Skip to content

add element-wise mul for SparseTensor #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions test/test_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor, mul

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_mul(dtype, device):
    """Element-wise sparse-sparse mul keeps only the overlapping entries."""
    row_a = torch.tensor([0, 0, 1, 2, 2], device=device)
    col_a = torch.tensor([0, 2, 1, 0, 1], device=device)
    val_a = tensor([1, 2, 4, 1, 3], dtype, device)
    mat_a = SparseTensor(row=row_a, col=col_a, value=val_a)

    row_b = torch.tensor([0, 0, 1, 2, 2], device=device)
    col_b = torch.tensor([1, 2, 2, 1, 2], device=device)
    val_b = tensor([2, 3, 1, 2, 4], dtype, device)
    mat_b = SparseTensor(row=row_b, col=col_b, value=val_b)

    out = mat_a * mat_b
    out_row, out_col, out_val = out.coo()

    # Only (0, 2) and (2, 1) occur in both inputs: 2 * 3 and 3 * 2.
    assert out_row.tolist() == [0, 2]
    assert out_col.tolist() == [2, 1]
    assert out_val.tolist() == [6, 6]

    # The SparseTensor x SparseTensor overload must be TorchScript-compatible.
    @torch.jit.script
    def scripted_mul(a: SparseTensor, b: SparseTensor) -> SparseTensor:
        return mul(a, b)

    scripted_mul(mat_a, mat_b)
85 changes: 70 additions & 15 deletions torch_sparse/mul.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,82 @@
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor


def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
else:
raise ValueError(
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
@torch.jit._overload  # noqa: F811
def mul(src, other):  # noqa: F811
    # type: (SparseTensor, Tensor) -> SparseTensor
    # TorchScript overload stub: broadcasted multiplication with a dense
    # Tensor (row-wise or column-wise). The real body lives in the final
    # `mul` definition below; the type comment above drives jit dispatch.
    pass


@torch.jit._overload  # noqa: F811
def mul(src, other):  # noqa: F811
    # type: (SparseTensor, SparseTensor) -> SparseTensor
    # TorchScript overload stub: element-wise multiplication of two sparse
    # tensors. The real body lives in the final `mul` definition below; the
    # type comment above drives jit dispatch.
    pass


def mul(src, other):  # noqa: F811
    """Multiply a :class:`SparseTensor` by a dense tensor or another sparse one.

    Dense ``other`` must be of shape ``(src.size(0), 1)`` (row-wise scaling)
    or ``(1, src.size(1))`` (column-wise scaling); the result keeps ``src``'s
    sparsity pattern. Sparse ``other`` is multiplied element-wise: only the
    entries present in *both* inputs survive (everything else is zero), and
    both inputs must be coalesced.

    Raises:
        ValueError: if a dense ``other`` has an unsupported shape, or if a
            sparse input is not coalesced.
        TypeError: if a sparse input carries no values.
        NotImplementedError: for any other ``other`` type.
    """
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')

        if value is not None:
            # `other` is a fresh tensor here (gather/index output), so the
            # in-place mul_ cannot clobber the caller's tensor.
            value = other.to(value.dtype).mul_(value)
        else:
            value = other
        return src.set_value(value, layout='coo')

    elif isinstance(other, SparseTensor):  # Element-wise sparse-sparse mul.
        # Duplicate detection below relies on each input holding every
        # (row, col) pair at most once.
        if not (src.is_coalesced() and other.is_coalesced()):
            raise ValueError('SparseTensor is not coalesced.')

        rowA, colA, valueA = src.coo()
        rowB, colB, valueB = other.coo()

        if valueA is None or valueB is None:
            raise TypeError('Value of SparseTensor is None.')

        row = torch.cat([rowA, rowB], dim=0)
        col = torch.cat([colA, colB], dim=0)
        value = torch.cat([valueA, valueB], dim=0)

        M = max(src.size(0), other.size(0))
        N = max(src.size(1), other.size(1))
        sparse_sizes = (M, N)

        # Sort by linearized (row, col) index so that entries shared by both
        # inputs become adjacent. idx[0] = -1 acts as a sentinel so the first
        # sorted entry is never flagged as a duplicate.
        idx = col.new_full((col.numel() + 1, ), -1)
        idx[1:] = row * N + col
        perm = idx[1:].argsort()
        row, col, value = row[perm], col[perm], value[perm]

        idx[1:] = idx[1:][perm]
        mask = idx[1:] > idx[:-1]

        if mask.all():  # No overlap: the element-wise product is all-zero.
            # BUGFIX: allocate zeros on the same device/dtype as `value`
            # (plain torch.zeros(..., dtype=...) would land on the CPU and
            # mismatch CUDA row/col indices).
            return SparseTensor(row=row, col=col,
                                value=value.new_zeros(value.numel()),
                                sparse_sizes=sparse_sizes)

        rmask = ~mask  # Second occurrence of each duplicated index.
        ridx = rmask.nonzero().flatten()
        # value[ridx - 1] is the entry from one input, value[ridx] the
        # matching entry from the other; their product is the output value.
        return SparseTensor(row=row[rmask], col=col[rmask],
                            value=value[ridx - 1] * value[ridx],
                            sparse_sizes=sparse_sizes)

    # BUGFIX: the original fell through to code that read an unbound `value`
    # (NameError) before ever reaching this raise.
    raise NotImplementedError


def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: