Skip to content

Add @ operator (simplify) #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
May 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions sparse/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,21 @@ def T(self):
def dot(self, other):
return dot(self, other)

def __matmul__(self, other):
try:
return dot(self, other)
except NotImplementedError:
return NotImplemented

def __rmatmul__(self, other):
try:
return dot(other, self)
except NotImplementedError:
return NotImplemented

def __numpy_ufunc__(self, ufunc, method, i, inputs, **kwargs):
return NotImplemented

def linear_loc(self, signed=False):
""" Index location of every piece of data in a flattened array

Expand Down Expand Up @@ -802,6 +817,10 @@ def tensordot(a, b, axes=2):


def dot(a, b):
if not hasattr(a, 'ndim') or not hasattr(b, 'ndim'):
raise NotImplementedError(
"Cannot perform dot product on types %s, %s" %
(type(a), type(b)))
return tensordot(a, b, axes=((a.ndim - 1,), (b.ndim - 2,)))


Expand Down
28 changes: 28 additions & 0 deletions sparse/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def test_tensordot(a_shape, b_shape, axes):


def test_dot():
import operator
a = random_x((3, 4, 5))
b = random_x((5, 6))

Expand All @@ -136,6 +137,33 @@ def test_dot():
assert_eq(a.dot(b), sa.dot(sb))
assert_eq(np.dot(a, b), sparse.dot(sa, sb))

if hasattr(operator, 'matmul'):
# Basic equivalences
assert_eq(eval("a @ b"), eval("sa @ sb"))
assert_eq(eval("sa @ sb"), sparse.dot(sa, sb))

# Test that SOO's and np.array's combine correctly
assert_eq(eval("a @ sb"), eval("sa @ b"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could also check np.dot(sa, sb). If the classes are set up correctly, this should know how to delegate to sparse.dot (though I think it requires __numpy_ufunc__)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gets a failure that doesn't seem straightforward to fix. But moreover, I wouldn't really expect it to work as a user.

sparse/core.py:448: in __mul__
    return self.elemwise_binary(operator.mul, other)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <COO: shape=(3, 4, 5), dtype=float64, nnz=6>
func = <built-in function mul>
other = <COO: shape=(5, 6), dtype=float64, nnz=4>, args = (), kwargs = {}

    def elemwise_binary(self, func, other, *args, **kwargs):
        assert isinstance(other, COO)
        if kwargs.pop('check', True) and func(0, 0, *args, **kwargs) != 0:
            raise ValueError("Performing this operation would produce "
                    "a dense result: %s" % str(func))
        if self.shape != other.shape:
>           raise NotImplementedError("Broadcasting is not supported")
E           NotImplementedError: Broadcasting is not supported

sparse/core.py:478: NotImplementedError

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird.. that error comes from the fact that np.dot(sa, sb) ends up calling sa.__mul__(sb), which... I have no idea why it would do that.



@pytest.mark.xfail
def test_dot_nocoercion():
a = random_x((3, 4, 5))
b = random_x((5, 6))

la = a.tolist()
lb = b.tolist()
la, lb # silencing flake8

sa = COO.from_numpy(a)
sb = COO.from_numpy(b)
sa, sb # silencing flake8

if hasattr(operator, 'matmul'):
# Operations with naive collection (list)
assert_eq(eval("la @ b"), eval("la @ sb"))
assert_eq(eval("a @ lb"), eval("sa @ lb"))


@pytest.mark.parametrize('func', [np.expm1, np.log1p, np.sin, np.tan,
np.sinh, np.tanh, np.floor, np.ceil,
Expand Down