Skip to content

Commit

Permalink
Backport PR #1806: Fix backed sparse matrix compat with scipy 1.15
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored and meeseeksmachine committed Dec 19, 2024
1 parent 56a99fd commit e72b707
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1806.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add {mod}`scipy` 1.5 compatibility {user}`flying-sheep`
22 changes: 18 additions & 4 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

import h5py
import numpy as np
import scipy
import scipy.sparse as ss
from packaging.version import Version
from scipy.sparse import _sparsetools

from .. import abc
Expand All @@ -39,11 +41,14 @@

from .._types import GroupStorageType
from ..compat import H5Array
from .index import Index
from .index import Index, Index1D
else:
from scipy.sparse import spmatrix as _cs_matrix


SCIPY_1_15 = Version(scipy.__version__) >= Version("1.15rc0")


class BackedFormat(NamedTuple):
format: Literal["csr", "csc"]
backed_type: type[BackedSparseMatrix]
Expand Down Expand Up @@ -353,7 +358,9 @@ def _get_group_format(group: GroupStorageType) -> str:


# Check for the overridden few methods above in our BackedSparseMatrix subclasses
def is_sparse_indexing_overridden(format: Literal["csr", "csc"], row, col):
def is_sparse_indexing_overridden(
format: Literal["csr", "csc"], row: Index1D, col: Index1D
):
major_indexer, minor_indexer = (row, col) if format == "csr" else (col, row)
return isinstance(minor_indexer, slice) and (
(isinstance(major_indexer, int | np.integer))
Expand All @@ -362,6 +369,13 @@ def is_sparse_indexing_overridden(format: Literal["csr", "csc"], row, col):
)


def validate_indices(
mtx: BackedSparseMatrix, indices: tuple[Index1D, Index1D]
) -> tuple[Index1D, Index1D]:
res = mtx._validate_indices(indices)
return res[0] if SCIPY_1_15 else res


class BaseCompressedSparseDataset(abc._AbstractCSDataset, ABC):
_group: GroupStorageType

Expand Down Expand Up @@ -424,8 +438,8 @@ def __getitem__(
indices = self._normalize_index(index)
row, col = indices
mtx = self._to_backed()
row_sp_matrix_validated, col_sp_matrix_validated = mtx._validate_indices(
(row, col)
row_sp_matrix_validated, col_sp_matrix_validated = validate_indices(
mtx, indices
)

# Handle masked indexing along major axis
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Empty:
pass


Index1D = slice | int | str | np.int64 | np.ndarray
Index1D = slice | int | str | np.int64 | np.ndarray | pd.Series
IndexRest = Index1D | EllipsisType
Index = (
IndexRest
Expand Down
17 changes: 14 additions & 3 deletions tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,17 +615,28 @@ def test_backed_sizeof(
assert csr_mem.__sizeof__() > csc_disk.__sizeof__()


sparray_scipy_bug_marks = (
[pytest.mark.skip(reason="scipy bug causes view to be allocated")]
if CAN_USE_SPARSE_ARRAY
else []
)


@pytest.mark.parametrize(
"group_fn",
[
pytest.param(lambda _: zarr.group(), id="zarr"),
pytest.param(lambda p: h5py.File(p / "test.h5", mode="a"), id="h5py"),
],
)
@pytest.mark.parametrize("sparse_class", [sparse.csr_matrix, sparse.csr_array])
@pytest.mark.parametrize(
"sparse_class",
[
sparse.csr_matrix,
pytest.param(sparse.csr_array, marks=[*sparray_scipy_bug_marks]),
],
)
def test_append_overflow_check(group_fn, sparse_class, tmpdir):
if CAN_USE_SPARSE_ARRAY and issubclass(sparse_class, SpArray):
pytest.skip("scipy bug causes view to be allocated")
group = group_fn(tmpdir)
typemax_int32 = np.iinfo(np.int32).max
orig_mtx = sparse_class(np.ones((1, 1), dtype=bool))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,7 @@ def test_concat_outer_aligned_mapping(elem):
del b.obsm[elem]

concated = concat({"a": a, "b": b}, join="outer", label="group")
result = concated.obsm[elem][concated.obs["group"] == "b"]
result = concated[concated.obs["group"] == "b"].obsm[elem]

check_filled_like(result, elem_name=f"obsm/{elem}")

Expand Down

0 comments on commit e72b707

Please sign in to comment.