diff --git a/sparse/core.py b/sparse/core.py index be9c06c7..0ef4749a 100644 --- a/sparse/core.py +++ b/sparse/core.py @@ -90,6 +90,7 @@ class COO(object): __array_priority__ = 12 def __init__(self, coords, data=None, shape=None, has_duplicates=True): + if data is None: # {(i, j, k): x, (i, j, k): y, ...} if isinstance(coords, dict): @@ -328,6 +329,11 @@ def T(self): def dot(self, other): return dot(self, other) + __matmul__ = dot + + def __rmatmul__(self, other): + return dot(other, self) + def reshape(self, shape): if self.shape == shape: return self @@ -681,6 +687,10 @@ def tensordot(a, b, axes=2): def dot(a, b): + if isinstance(a, (list, tuple)): + a = np.array(a) + if isinstance(b, (list, tuple)): + b = np.array(b) return tensordot(a, b, axes=((a.ndim - 1,), (b.ndim - 2,))) diff --git a/sparse/tests/test_core.py b/sparse/tests/test_core.py index e531245f..d0238d98 100644 --- a/sparse/tests/test_core.py +++ b/sparse/tests/test_core.py @@ -1,5 +1,6 @@ import pytest +import sys import random import operator import numpy as np @@ -127,15 +128,55 @@ def test_tensordot(a_shape, b_shape, axes): def test_dot(): + import operator a = random_x((3, 4, 5)) b = random_x((5, 6)) + la = a.tolist() + lb = b.tolist() + la, lb # silencing flake8 + sa = COO.from_numpy(a) sb = COO.from_numpy(b) assert_eq(a.dot(b), sa.dot(sb)) assert_eq(np.dot(a, b), sparse.dot(sa, sb)) + + if hasattr(operator, 'matmul'): + # Basic equivalences + assert_eq(eval("a @ b"), eval("sa @ sb")) + assert_eq(eval("sa @ sb"), sparse.dot(sa, sb)) + + # Exercise __rmatmul__ with naive collection (list) + assert_eq(eval("la @ b"), eval("la @ sb")) + assert_eq(eval("a @ sb"), sparse.dot(a, sb)) + assert_eq(eval("a @ lb"), eval("sa @ lb")) + + # Test that SOO's and np.array's combine correctly + assert_eq(eval("a @ sb"), eval("sa @ b")) + + +@pytest.mark.xfail +def test_dot_nocoercion(): + # Expect failure with some non-list, non-tuple collection + # that cannot be coerced straightforwardly + a = random_x((3, 4, 5)) + b = random_x((5, 6)) + + set_a = set(a.tolist()) + set_b = set(b.tolist()) + set_a, set_b # silencing flake8 + + sa = COO.from_numpy(a) + sb = COO.from_numpy(b) + sa, sb # silencing flake8 + + if sys.version_info >= (3, 5): + # Operations with naive collection (list) + assert_eq(eval("set_a @ b"), eval("set_a @ sb")) + assert_eq(eval("a @ set_b"), eval("sa @ set_b")) + @pytest.mark.parametrize('func', [np.expm1, np.log1p, np.sin, np.tan, np.sinh, np.tanh, np.floor, np.ceil,