Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Improve sparse.asarray #615

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._version import get_versions

__version__ = get_versions()["version"]
__array_api_version__ = "2022.12"
del get_versions

from numpy import (
Expand Down
32 changes: 16 additions & 16 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,7 +1819,7 @@
array([[0, 0],
[0, 0]])
"""
return full(shape, 0, np.dtype(dtype)).asformat(format, **kwargs)
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs)


def zeros_like(a, dtype=None, shape=None, format=None, **kwargs):
Expand Down Expand Up @@ -1848,7 +1848,7 @@
array([[0, 0, 0],
[0, 0, 0]])
"""
return full_like(a, 0, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs)


def ones(shape, dtype=float, format="coo", **kwargs):
Expand Down Expand Up @@ -1880,7 +1880,7 @@
array([[1, 1],
[1, 1]])
"""
return full(shape, 1, np.dtype(dtype)).asformat(format, **kwargs)
return full(shape, fill_value=1, dtype=np.dtype(dtype), format=format, **kwargs)


def ones_like(a, dtype=None, shape=None, format=None, **kwargs):
Expand Down Expand Up @@ -1909,18 +1909,18 @@
array([[1, 1, 1],
[1, 1, 1]])
"""
return full_like(a, 1, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=1, dtype=dtype, shape=shape, format=format, **kwargs)


def empty(shape, dtype=float, format="coo", **kwargs):
return full(shape, 0, np.dtype(dtype)).asformat(format, **kwargs)
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs)

Check warning on line 1916 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L1916

Added line #L1916 was not covered by tests


empty.__doc__ = zeros.__doc__


def empty_like(a, dtype=None, shape=None, format=None, **kwargs):
return full_like(a, 0, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs)

Check warning on line 1923 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L1923

Added line #L1923 was not covered by tests


empty_like.__doc__ = zeros_like.__doc__
Expand Down Expand Up @@ -2159,16 +2159,16 @@
return obj

elif isinstance(obj, spmatrix):
return format_dict[format].from_scipy_sparse(
obj.astype(dtype=dtype, copy=copy)
)

# check for scalars and 0-D arrays
elif np.isscalar(obj) or (isinstance(obj, np.ndarray) and obj.shape == ()):
return np.asarray(obj, dtype=dtype)

elif isinstance(obj, np.ndarray):
return format_dict[format].from_numpy(obj).astype(dtype=dtype, copy=copy)
sparse_obj = format_dict[format].from_scipy_sparse(obj)
if dtype is None:
dtype = sparse_obj.dtype

Check warning on line 2164 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2164

Added line #L2164 was not covered by tests
return sparse_obj.astype(dtype=dtype, copy=copy)

elif np.isscalar(obj) or isinstance(obj, (np.ndarray, Iterable)):
sparse_obj = format_dict[format].from_numpy(np.asarray(obj))
if dtype is None:
dtype = sparse_obj.dtype

Check warning on line 2170 in sparse/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_common.py#L2170

Added line #L2170 was not covered by tests
return sparse_obj.astype(dtype=dtype, copy=copy)

else:
raise ValueError(f"{type(obj)} not supported.")
Expand Down
9 changes: 7 additions & 2 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,11 +1145,16 @@
assert mode in ("max", "min")
max_mode_flag = mode == "max"

from .core import COO
Dismissed Show dismissed Hide dismissed

if not isinstance(x, COO):
raise ValueError(f"Only COO arrays are supported but {type(x)} was passed.")

Check warning on line 1151 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1151

Added line #L1151 was not covered by tests

if not isinstance(axis, (int, type(None))):
raise ValueError(f"axis must be int or None, but it's: {type(axis)}")
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")

Check warning on line 1154 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1154

Added line #L1154 was not covered by tests
if isinstance(axis, int) and axis >= x.ndim:
raise ValueError(
f"axis {axis} is out of bounds for array of dimension {x.ndim}"
f"`axis={axis}` is out of bounds for array of dimension {x.ndim}."
)
if x.ndim == 0:
raise ValueError("Input array must be at least 1-D, but it's 0-D.")
Expand Down
11 changes: 11 additions & 0 deletions sparse/_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,17 @@
"""
return np.conj(self)

def __array_namespace__(self, *, api_version=None):
if api_version is None:
api_version = "2022.12"

Check warning on line 992 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L991-L992

Added lines #L991 - L992 were not covered by tests

if api_version in ("2021.12", "2022.12"):
import sparse

Check warning on line 995 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L994-L995

Added lines #L994 - L995 were not covered by tests

return sparse

Check warning on line 997 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L997

Added line #L997 was not covered by tests
else:
raise ValueError(f'"{api_version}" Array API version not supported.')

Check warning on line 999 in sparse/_sparse_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/_sparse_array.py#L999

Added line #L999 was not covered by tests

def __bool__(self):
""" """
return self._to_scalar(bool)
Expand Down
2 changes: 2 additions & 0 deletions sparse/_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ def __init__(self, func, *args, **kwargs):

sparse_args = [arg for arg in args if isinstance(arg, SparseArray)]

if len(sparse_args) == 0:
raise ValueError(f"None of the args is sparse: {args}")
if all(isinstance(arg, DOK) for arg in sparse_args):
out_type = DOK
elif all(isinstance(arg, GCXS) for arg in sparse_args):
Expand Down
7 changes: 4 additions & 3 deletions sparse/tests/test_array_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ class TestAsarray:
@pytest.mark.parametrize("dtype", [np.int64, np.float64, np.complex128])
@pytest.mark.parametrize("format", ["dok", "gcxs", "coo"])
def test_asarray(self, input, dtype, format):
if format == "dok" and (np.isscalar(input) or input.ndim == 0):
# scalars and 0-D arrays aren't supported in DOK format
return

s = sparse.asarray(input, dtype=dtype, format=format)

actual = s.todense() if hasattr(s, "todense") else s
Expand All @@ -132,9 +136,6 @@ def test_asarray_special_cases(self):
with pytest.raises(ValueError, match="Taco not yet supported."):
sparse.asarray(self.np_eye, backend="taco")

with pytest.raises(ValueError, match="<class 'list'> not supported."):
sparse.asarray([1, 2, 3])

with pytest.raises(ValueError, match="any backend not supported."):
sparse.asarray(self.np_eye, backend="any")

Expand Down
2 changes: 1 addition & 1 deletion sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,7 +1736,7 @@ def test_argmax_argmin_constraint(func):
s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2)

with pytest.raises(
ValueError, match="axis 2 is out of bounds for array of dimension 2"
ValueError, match="`axis=2` is out of bounds for array of dimension 2."
):
func(s, axis=2)

Expand Down
3 changes: 2 additions & 1 deletion sparse/tests/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ def test_elemwise_noargs():
def func():
return np.float_(5.0)

assert_eq(sparse.elemwise(func), func())
with pytest.raises(ValueError, match=r"None of the args is sparse:"):
sparse.elemwise(func)


@pytest.mark.parametrize(
Expand Down
Loading