diff --git a/sparse/core.py b/sparse/core.py index e363db34..43813bfe 100644 --- a/sparse/core.py +++ b/sparse/core.py @@ -375,6 +375,21 @@ def T(self): def dot(self, other): return dot(self, other) + def __matmul__(self, other): + try: + return dot(self, other) + except NotImplementedError: + return NotImplemented + + def __rmatmul__(self, other): + try: + return dot(other, self) + except NotImplementedError: + return NotImplemented + + def __numpy_ufunc__(self, ufunc, method, i, inputs, **kwargs): + return NotImplemented + def linear_loc(self, signed=False): """ Index location of every piece of data in a flattened array @@ -802,6 +817,10 @@ def tensordot(a, b, axes=2): def dot(a, b): + if not hasattr(a, 'ndim') or not hasattr(b, 'ndim'): + raise NotImplementedError( + "Cannot perform dot product on types %s, %s" % + (type(a), type(b))) return tensordot(a, b, axes=((a.ndim - 1,), (b.ndim - 2,))) diff --git a/sparse/tests/test_core.py b/sparse/tests/test_core.py index d6c4f171..51af459c 100644 --- a/sparse/tests/test_core.py +++ b/sparse/tests/test_core.py @@ -127,6 +127,7 @@ def test_tensordot(a_shape, b_shape, axes): def test_dot(): + import operator a = random_x((3, 4, 5)) b = random_x((5, 6)) @@ -136,6 +137,33 @@ def test_dot(): assert_eq(a.dot(b), sa.dot(sb)) assert_eq(np.dot(a, b), sparse.dot(sa, sb)) + if hasattr(operator, 'matmul'): + # Basic equivalences + assert_eq(eval("a @ b"), eval("sa @ sb")) + assert_eq(eval("sa @ sb"), sparse.dot(sa, sb)) + + # Test that SOO's and np.array's combine correctly + assert_eq(eval("a @ sb"), eval("sa @ b")) + + +@pytest.mark.xfail +def test_dot_nocoercion(): + a = random_x((3, 4, 5)) + b = random_x((5, 6)) + + la = a.tolist() + lb = b.tolist() + la, lb # silencing flake8 + + sa = COO.from_numpy(a) + sb = COO.from_numpy(b) + sa, sb # silencing flake8 + + if hasattr(operator, 'matmul'): + # Operations with naive collection (list) + assert_eq(eval("la @ b"), eval("la @ sb")) + assert_eq(eval("a @ lb"), eval("sa @ lb")) + @pytest.mark.parametrize('func', [np.expm1, np.log1p, np.sin, np.tan, np.sinh, np.tanh, np.floor, np.ceil,