Skip to content

Matmul/generic #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions sparse/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ class COO(object):
"""
__array_priority__ = 12

def __init__(self, coords, data=None, shape=None, has_duplicates=True):
def __init__(self, coords, data=None, shape=None,
has_duplicates=True, toarray_other=False):

self._toarray_other = toarray_other
if data is None:
# {(i, j, k): x, (i, j, k): y, ...}
if isinstance(coords, dict):
Expand Down Expand Up @@ -325,8 +328,14 @@ def transpose(self, axes=None):
def T(self):
return self.transpose(list(range(self.ndim))[::-1])

def dot(self, other):
return dot(self, other)
def dot(self, other, make_array=False):
return dot(self, other,
make_array=make_array or self._toarray_other)

__matmul__ = dot

def __rmatmul__(self, other):
return dot(other, self)

def reshape(self, shape):
if self.shape == shape:
Expand Down Expand Up @@ -680,7 +689,12 @@ def tensordot(a, b, axes=2):
return res.reshape(olda + oldb)


def dot(a, b):
def dot(a, b, make_array=True):
if make_array:
if not hasattr(a, 'ndim'):
a = np.array(a)
if not hasattr(b, 'ndim'):
b = np.array(b)
return tensordot(a, b, axes=((a.ndim - 1,), (b.ndim - 2,)))


Expand Down
41 changes: 41 additions & 0 deletions sparse/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import sys
import random
import operator
import numpy as np
Expand Down Expand Up @@ -130,12 +131,52 @@ def test_dot():
a = random_x((3, 4, 5))
b = random_x((5, 6))

la = a.tolist()
lb = b.tolist()
la, lb # silencing flake8

sa = COO.from_numpy(a)
sb = COO.from_numpy(b)

assert_eq(a.dot(b), sa.dot(sb))
assert_eq(np.dot(a, b), sparse.dot(sa, sb))

if sys.version_info >= (3, 5):
# Coerce to np.array for arg lacking __matmul__
sa._toarray_other = True
sb._toarray_other = True

# Basic equivalences
assert_eq(eval("a @ b"), eval("sa @ sb"))
assert_eq(eval("sa @ sb"), sparse.dot(sa, sb))

# Exercise __rmatmul__ with naive collection (list)
assert_eq(eval("la @ b"), eval("la @ sb"))
assert_eq(eval("a @ sb"), sparse.dot(a, sb))
assert_eq(eval("a @ lb"), eval("sa @ lb"))

# Test that SOO's and np.array's combine correctly
assert_eq(eval("a @ sb"), eval("sa @ b"))


@pytest.mark.xfail
def test_dot_nocoercion():
a = random_x((3, 4, 5))
b = random_x((5, 6))

la = a.tolist()
lb = b.tolist()
la, lb # silencing flake8

sa = COO.from_numpy(a)
sb = COO.from_numpy(b)
sa, sb # silencing flake8

if sys.version_info >= (3, 5):
# Operations with naive collection (list)
assert_eq(eval("la @ b"), eval("la @ sb"))
assert_eq(eval("a @ lb"), eval("sa @ lb"))


@pytest.mark.parametrize('func', [np.expm1, np.log1p, np.sin, np.tan,
np.sinh, np.tanh, np.floor, np.ceil,
Expand Down