Skip to content

Matmul/generic2 #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sparse/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class COO(object):
__array_priority__ = 12

def __init__(self, coords, data=None, shape=None, has_duplicates=True):

if data is None:
# {(i, j, k): x, (i, j, k): y, ...}
if isinstance(coords, dict):
Expand Down Expand Up @@ -328,6 +329,11 @@ def T(self):
def dot(self, other):
return dot(self, other)

__matmul__ = dot

def __rmatmul__(self, other):
return dot(other, self)

def reshape(self, shape):
if self.shape == shape:
return self
Expand Down Expand Up @@ -681,6 +687,10 @@ def tensordot(a, b, axes=2):


def dot(a, b):
if isinstance(a, (list, tuple)):
a = np.array(a)
if isinstance(b, (list, tuple)):
b = np.array(b)
return tensordot(a, b, axes=((a.ndim - 1,), (b.ndim - 2,)))


Expand Down
41 changes: 41 additions & 0 deletions sparse/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import sys
import random
import operator
import numpy as np
Expand Down Expand Up @@ -127,15 +128,55 @@ def test_tensordot(a_shape, b_shape, axes):


def test_dot():
import operator
a = random_x((3, 4, 5))
b = random_x((5, 6))

la = a.tolist()
lb = b.tolist()
la, lb # silencing flake8

sa = COO.from_numpy(a)
sb = COO.from_numpy(b)

assert_eq(a.dot(b), sa.dot(sb))
assert_eq(np.dot(a, b), sparse.dot(sa, sb))


if hasattr(operator, 'matmul'):
# Basic equivalences
assert_eq(eval("a @ b"), eval("sa @ sb"))
assert_eq(eval("sa @ sb"), sparse.dot(sa, sb))

# Exercise __rmatmul__ with naive collection (list)
assert_eq(eval("la @ b"), eval("la @ sb"))
assert_eq(eval("a @ sb"), sparse.dot(a, sb))
assert_eq(eval("a @ lb"), eval("sa @ lb"))

# Test that SOO's and np.array's combine correctly
assert_eq(eval("a @ sb"), eval("sa @ b"))


@pytest.mark.xfail
def test_dot_nocoercion():
# Expect failure with some non-list, non-tuple collection
# that cannot be coerced straightforwardly
a = random_x((3, 4, 5))
b = random_x((5, 6))

set_a = set(a.tolist())
set_b = set(b.tolist())
set_a, set_b # silencing flake8

sa = COO.from_numpy(a)
sb = COO.from_numpy(b)
sa, sb # silencing flake8

if sys.version_info >= (3, 5):
# Operations with naive collection (list)
assert_eq(eval("set_a @ b"), eval("set_a @ sb"))
assert_eq(eval("a @ set_b"), eval("sa @ set_b"))


@pytest.mark.parametrize('func', [np.expm1, np.log1p, np.sin, np.tan,
np.sinh, np.tanh, np.floor, np.ceil,
Expand Down