diff --git a/sparse/core.py b/sparse/core.py index be9c06c7..cafd79a7 100644 --- a/sparse/core.py +++ b/sparse/core.py @@ -89,7 +89,10 @@ class COO(object): """ __array_priority__ = 12 - def __init__(self, coords, data=None, shape=None, has_duplicates=True): + def __init__(self, coords, data=None, shape=None, + has_duplicates=True, toarray_other=False): + + self._toarray_other = toarray_other if data is None: # {(i, j, k): x, (i, j, k): y, ...} if isinstance(coords, dict): @@ -325,8 +328,14 @@ def transpose(self, axes=None): def T(self): return self.transpose(list(range(self.ndim))[::-1]) - def dot(self, other): - return dot(self, other) + def dot(self, other, make_array=False): + return dot(self, other, + make_array=make_array or self._toarray_other) + + __matmul__ = dot + + def __rmatmul__(self, other): + return dot(other, self) def reshape(self, shape): if self.shape == shape: @@ -680,7 +689,12 @@ def tensordot(a, b, axes=2): return res.reshape(olda + oldb) -def dot(a, b): +def dot(a, b, make_array=True): + if make_array: + if not hasattr(a, 'ndim'): + a = np.array(a) + if not hasattr(b, 'ndim'): + b = np.array(b) return tensordot(a, b, axes=((a.ndim - 1,), (b.ndim - 2,))) diff --git a/sparse/tests/test_core.py b/sparse/tests/test_core.py index e531245f..3405b64b 100644 --- a/sparse/tests/test_core.py +++ b/sparse/tests/test_core.py @@ -1,5 +1,6 @@ import pytest +import sys import random import operator import numpy as np @@ -130,12 +131,52 @@ def test_dot(): a = random_x((3, 4, 5)) b = random_x((5, 6)) + la = a.tolist() + lb = b.tolist() + la, lb # silencing flake8 + sa = COO.from_numpy(a) sb = COO.from_numpy(b) assert_eq(a.dot(b), sa.dot(sb)) assert_eq(np.dot(a, b), sparse.dot(sa, sb)) + if sys.version_info >= (3, 5): + # Coerce to np.array for arg lacking __matmul__ + sa._toarray_other = True + sb._toarray_other = True + + # Basic equivalences + assert_eq(eval("a @ b"), eval("sa @ sb")) + assert_eq(eval("sa @ sb"), sparse.dot(sa, sb)) + + # Exercise __rmatmul__ with naive collection (list) + assert_eq(eval("la @ b"), eval("la @ sb")) + assert_eq(eval("a @ sb"), sparse.dot(a, sb)) + assert_eq(eval("a @ lb"), eval("sa @ lb")) + + # Test that SOO's and np.array's combine correctly + assert_eq(eval("a @ sb"), eval("sa @ b")) + + +@pytest.mark.xfail +def test_dot_nocoercion(): + a = random_x((3, 4, 5)) + b = random_x((5, 6)) + + la = a.tolist() + lb = b.tolist() + la, lb # silencing flake8 + + sa = COO.from_numpy(a) + sb = COO.from_numpy(b) + sa, sb # silencing flake8 + + if sys.version_info >= (3, 5): + # Operations with naive collection (list) + assert_eq(eval("la @ b"), eval("la @ sb")) + assert_eq(eval("a @ lb"), eval("sa @ lb")) + @pytest.mark.parametrize('func', [np.expm1, np.log1p, np.sin, np.tan, np.sinh, np.tanh, np.floor, np.ceil,