Skip to content

Commit

Permalink
API: Improve sparse.asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jan 2, 2024
1 parent b12d51d commit 215bca6
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 23 deletions.
1 change: 1 addition & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._version import get_versions

__version__ = get_versions()["version"]
__array_api_version__ = "2022.12"
del get_versions

from numpy import (
Expand Down
32 changes: 16 additions & 16 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,7 +1819,7 @@ def zeros(shape, dtype=float, format="coo", **kwargs):
array([[0, 0],
[0, 0]])
"""
return full(shape, 0, np.dtype(dtype)).asformat(format, **kwargs)
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs)


def zeros_like(a, dtype=None, shape=None, format=None, **kwargs):
Expand Down Expand Up @@ -1848,7 +1848,7 @@ def zeros_like(a, dtype=None, shape=None, format=None, **kwargs):
array([[0, 0, 0],
[0, 0, 0]])
"""
return full_like(a, 0, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs)


def ones(shape, dtype=float, format="coo", **kwargs):
Expand Down Expand Up @@ -1880,7 +1880,7 @@ def ones(shape, dtype=float, format="coo", **kwargs):
array([[1, 1],
[1, 1]])
"""
return full(shape, 1, np.dtype(dtype)).asformat(format, **kwargs)
return full(shape, fill_value=1, dtype=np.dtype(dtype), format=format, **kwargs)


def ones_like(a, dtype=None, shape=None, format=None, **kwargs):
Expand Down Expand Up @@ -1909,18 +1909,18 @@ def ones_like(a, dtype=None, shape=None, format=None, **kwargs):
array([[1, 1, 1],
[1, 1, 1]])
"""
return full_like(a, 1, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=1, dtype=dtype, shape=shape, format=format, **kwargs)


def empty(shape, dtype=float, format="coo", **kwargs):
return full(shape, 0, np.dtype(dtype)).asformat(format, **kwargs)
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs)

Check warning on line 1916 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L1916

Added line #L1916 was not covered by tests


empty.__doc__ = zeros.__doc__


def empty_like(a, dtype=None, shape=None, format=None, **kwargs):
return full_like(a, 0, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs)

Check warning on line 1923 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L1923

Added line #L1923 was not covered by tests


empty_like.__doc__ = zeros_like.__doc__
Expand Down Expand Up @@ -2159,16 +2159,16 @@ def asarray(
return obj

elif isinstance(obj, spmatrix):
return format_dict[format].from_scipy_sparse(
obj.astype(dtype=dtype, copy=copy)
)

# check for scalars and 0-D arrays
elif np.isscalar(obj) or (isinstance(obj, np.ndarray) and obj.shape == ()):
return np.asarray(obj, dtype=dtype)

elif isinstance(obj, np.ndarray):
return format_dict[format].from_numpy(obj).astype(dtype=dtype, copy=copy)
sparse_obj = format_dict[format].from_scipy_sparse(obj)
if dtype is None:
dtype = sparse_obj.dtype

Check warning on line 2164 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2164

Added line #L2164 was not covered by tests
return sparse_obj.astype(dtype=dtype, copy=copy)

elif np.isscalar(obj) or isinstance(obj, (np.ndarray, Iterable)):
sparse_obj = format_dict[format].from_numpy(np.asarray(obj))
if dtype is None:
dtype = sparse_obj.dtype

Check warning on line 2170 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2170

Added line #L2170 was not covered by tests
return sparse_obj.astype(dtype=dtype, copy=copy)

else:
raise ValueError(f"{type(obj)} not supported.")
Expand Down
9 changes: 7 additions & 2 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,11 +1145,16 @@ def _arg_minmax_common(
assert mode in ("max", "min")
max_mode_flag = mode == "max"

from .core import COO

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
sparse._coo.core
begins an import cycle.

if not isinstance(x, COO):
raise ValueError(f"Only COO arrays are supported but {type(x)} was passed.")

Check warning on line 1151 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1151

Added line #L1151 was not covered by tests

if not isinstance(axis, (int, type(None))):
raise ValueError(f"axis must be int or None, but it's: {type(axis)}")
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")

Check warning on line 1154 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1154

Added line #L1154 was not covered by tests
if isinstance(axis, int) and axis >= x.ndim:
raise ValueError(
f"axis {axis} is out of bounds for array of dimension {x.ndim}"
f"`axis={axis}` is out of bounds for array of dimension {x.ndim}."
)
if x.ndim == 0:
raise ValueError("Input array must be at least 1-D, but it's 0-D.")
Expand Down
11 changes: 11 additions & 0 deletions sparse/_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,17 @@ def conj(self):
"""
return np.conj(self)

def __array_namespace__(self, *, api_version=None):
if api_version is None:
api_version = "2022.12"

Check warning on line 992 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L991-L992

Added lines #L991 - L992 were not covered by tests

if api_version in ("2021.12", "2022.12"):
import sparse

Check warning on line 995 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L994-L995

Added lines #L994 - L995 were not covered by tests

return sparse

Check warning on line 997 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L997

Added line #L997 was not covered by tests
else:
raise ValueError(f'"{api_version}" Array API version not supported.')

Check warning on line 999 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L999

Added line #L999 was not covered by tests

def __bool__(self):
""" """
return self._to_scalar(bool)
Expand Down
2 changes: 2 additions & 0 deletions sparse/_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ def __init__(self, func, *args, **kwargs):

sparse_args = [arg for arg in args if isinstance(arg, SparseArray)]

if len(sparse_args) == 0:
raise ValueError(f"None of the args is sparse: {args}")
if all(isinstance(arg, DOK) for arg in sparse_args):
out_type = DOK
elif all(isinstance(arg, GCXS) for arg in sparse_args):
Expand Down
7 changes: 4 additions & 3 deletions sparse/tests/test_array_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ class TestAsarray:
@pytest.mark.parametrize("dtype", [np.int64, np.float64, np.complex128])
@pytest.mark.parametrize("format", ["dok", "gcxs", "coo"])
def test_asarray(self, input, dtype, format):
if format == "dok" and (np.isscalar(input) or input.ndim == 0):
# scalars and 0-D arrays aren't supported in DOK format
return

s = sparse.asarray(input, dtype=dtype, format=format)

actual = s.todense() if hasattr(s, "todense") else s
Expand All @@ -132,9 +136,6 @@ def test_asarray_special_cases(self):
with pytest.raises(ValueError, match="Taco not yet supported."):
sparse.asarray(self.np_eye, backend="taco")

with pytest.raises(ValueError, match="<class 'list'> not supported."):
sparse.asarray([1, 2, 3])

with pytest.raises(ValueError, match="any backend not supported."):
sparse.asarray(self.np_eye, backend="any")

Expand Down
2 changes: 1 addition & 1 deletion sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,7 +1736,7 @@ def test_argmax_argmin_constraint(func):
s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2)

with pytest.raises(
ValueError, match="axis 2 is out of bounds for array of dimension 2"
ValueError, match="`axis=2` is out of bounds for array of dimension 2."
):
func(s, axis=2)

Expand Down
3 changes: 2 additions & 1 deletion sparse/tests/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ def test_elemwise_noargs():
def func():
return np.float_(5.0)

assert_eq(sparse.elemwise(func), func())
with pytest.raises(ValueError, match=r"None of the args is sparse:"):
sparse.elemwise(func)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 215bca6

Please sign in to comment.