Skip to content

Commit

Permalink
Backport PR #1616: (fix): correct typing of AnnData.X
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold authored and flying-sheep committed Aug 30, 2024
1 parent e184c55 commit 2514dce
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 85 deletions.
8 changes: 4 additions & 4 deletions docs/concatenation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ When the variables present in the objects to be concatenated aren't exactly the
This is otherwise called taking the `"inner"` (intersection) or `"outer"` (union) join.
For example, given two anndata objects with differing variables:

>>> a = AnnData(sparse.eye(3), var=pd.DataFrame(index=list("abc")))
>>> b = AnnData(sparse.eye(2), var=pd.DataFrame(index=list("ba")))
>>> a = AnnData(sparse.eye(3, format="csr"), var=pd.DataFrame(index=list("abc")))
>>> b = AnnData(sparse.eye(2, format="csr"), var=pd.DataFrame(index=list("ba")))
>>> ad.concat([a, b], join="inner").X.toarray()
array([[1., 0.],
[0., 1.],
Expand Down Expand Up @@ -208,11 +208,11 @@ Note that comparisons are made after indices are aligned.
That is, if the objects only share a subset of indices on the alternative axis, it's only required that values for those indices match when using a strategy like `"same"`.

>>> a = AnnData(
... sparse.eye(3),
... sparse.eye(3, format="csr"),
... var=pd.DataFrame({"nums": [1, 2, 3]}, index=list("abc"))
... )
>>> b = AnnData(
... sparse.eye(2),
... sparse.eye(2, format="csr"),
... var=pd.DataFrame({"nums": [2, 1]}, index=list("ba"))
... )
>>> ad.concat([a, b], merge="same").var
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/1616.doc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correct {attr}`anndata.AnnData.X` type to include {class}`~anndata.experimental.CSRDataset` and {class}`~anndata.experimental.CSCDataset` as possible types {user}`ilan-gold`
16 changes: 16 additions & 0 deletions hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,19 @@ dependencies = ["setuptools"] # https://bitbucket.org/pybtex-devs/pybtex/issues
[envs.docs.scripts]
build = "sphinx-build -M html docs docs/_build -W --keep-going {args}"
clean = "git clean -fX -- docs"

[envs.hatch-test]
default-args = []
extra-dependencies = ["ipykernel"]
features = ["dev", "test"]
overrides.matrix.deps.env-vars = [
{ key = "UV_PRERELEASE", value = "allow", if = ["pre"] },
{ key = "UV_RESOLUTION", value = "lowest-direct", if = ["min"] },
]
overrides.matrix.deps.python = [
{ if = ["min"], value = "3.9" },
{ if = ["stable", "pre"], value = "3.12" },
]

[[envs.hatch-test.matrix]]
deps = ["stable", "pre", "min"]
4 changes: 2 additions & 2 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
from os import PathLike
from typing import Any, Literal

from .._types import ArrayDataStructureType
from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
from .index import Index, Index1D
from .views import ArrayView


# for backwards compat
Expand Down Expand Up @@ -543,7 +543,7 @@ def shape(self) -> tuple[int, int]:
return self.n_obs, self.n_vars

@property
def X(self) -> np.ndarray | sparse.spmatrix | ArrayView | None:
def X(self) -> ArrayDataStructureType | None:
"""Data matrix of shape :attr:`n_obs` × :attr:`n_vars`."""
if self.isbacked:
if not self.file.is_open:
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ def gen_reindexer(new_var: pd.Index, cur_var: pd.Index):
Usage
-----
>>> a = AnnData(sparse.eye(3), var=pd.DataFrame(index=list("abc")))
>>> b = AnnData(sparse.eye(2), var=pd.DataFrame(index=list("ba")))
>>> a = AnnData(sparse.eye(3, format="csr"), var=pd.DataFrame(index=list("abc")))
>>> b = AnnData(sparse.eye(2, format="csr"), var=pd.DataFrame(index=list("ba")))
>>> reindexer = gen_reindexer(a.var_names, b.var_names)
>>> sparse.vstack([a.X, reindexer(b.X)]).toarray()
array([[1., 0., 0.],
Expand Down
71 changes: 29 additions & 42 deletions src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import warnings
from enum import Enum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, get_args

import numpy as np
import pandas as pd
Expand All @@ -24,47 +23,26 @@
join_english,
raise_value_error_if_multiindex_columns,
)
from .sparse_dataset import BaseCompressedSparseDataset
from .sparse_dataset import CSCDataset, CSRDataset

if TYPE_CHECKING:
from collections.abc import Generator
from typing import Any
from typing import Any, TypeAlias


class ArrayDataStructureType(Enum):
# Memory
Array = (np.ndarray, "np.ndarray")
Masked = (ma.MaskedArray, "numpy.ma.core.MaskedArray")
Sparse = (sparse.spmatrix, "scipy.sparse.spmatrix")
AwkArray = (AwkArray, "awkward.Array")
# Backed
HDF5Dataset = (H5Array, "h5py.Dataset")
ZarrArray = (ZarrArray, "zarr.Array")
ZappyArray = (ZappyArray, "zappy.base.ZappyArray")
BackedSparseMatrix = (
BaseCompressedSparseDataset,
"anndata.experimental.[CSC,CSR]Dataset",
)
# Distributed
DaskArray = (DaskArray, "dask.array.Array")
CupyArray = (CupyArray, "cupy.ndarray")
CupySparseMatrix = (CupySparseMatrix, "cupyx.scipy.sparse.spmatrix")

@property
def cls(self):
return self.value[0]

@property
def qualname(self):
return self.value[1]

@classmethod
def classes(cls) -> tuple[type, ...]:
return tuple(v.cls for v in cls)

@classmethod
def qualnames(cls) -> Generator[str, None, None]:
yield from (v.qualname for v in cls)
ArrayDataStructureType: TypeAlias = Union[
np.ndarray,
ma.MaskedArray,
sparse.csr_matrix,
sparse.csc_matrix,
AwkArray,
H5Array,
ZarrArray,
ZappyArray,
CSRDataset,
CSCDataset,
DaskArray,
CupyArray,
CupySparseMatrix,
]


def coerce_array(
Expand All @@ -79,12 +57,21 @@ def coerce_array(
if allow_array_like and np.isscalar(value):
return value
# If value is one of the allowed types, return it
if isinstance(value, ArrayDataStructureType.classes()):
array_data_structure_types = get_args(ArrayDataStructureType)
if isinstance(value, array_data_structure_types):
if isinstance(value, np.matrix):
msg = f"{name} should not be a np.matrix, use np.ndarray instead."
warnings.warn(msg, ImplicitModificationWarning)
value = value.A
return value
elif isinstance(value, sparse.spmatrix):
msg = (
f"AnnData previously had undefined behavior around matrices of type {type(value)}."
"In 0.12, passing in this type will throw an error. Please convert to a supported type."
"Continue using for this minor version at your own risk."
)
warnings.warn(msg, FutureWarning)
return value
if isinstance(value, pd.DataFrame):
if allow_df:
raise_value_error_if_multiindex_columns(value, name)
Expand All @@ -98,7 +85,7 @@ def coerce_array(
except (ValueError, TypeError) as _e:
e = _e
# if value isn’t the right type or convertible, raise an error
msg = f"{name} needs to be of one of {join_english(ArrayDataStructureType.qualnames())}, not {type(value)}."
msg = f"{name} needs to be of one of {join_english(map(str, array_data_structure_types))}, not {type(value)}."
if e is not None:
msg += " (Failed to convert it to an array, see above for details.)"
raise ValueError(msg) from e
27 changes: 3 additions & 24 deletions src/anndata/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,12 @@

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from scipy import sparse

from anndata._core.anndata import AnnData

from ._core.sparse_dataset import BaseCompressedSparseDataset
from ._core.anndata import AnnData
from ._core.storage import ArrayDataStructureType
from .compat import (
AwkArray,
CupyArray,
CupySparseMatrix,
DaskArray,
H5Array,
H5Group,
ZappyArray,
ZarrArray,
ZarrGroup,
)
Expand All @@ -40,20 +32,7 @@
]

InMemoryArrayOrScalarType: TypeAlias = Union[
NDArray,
np.ma.MaskedArray,
sparse.spmatrix,
H5Array,
ZarrArray,
ZappyArray,
BaseCompressedSparseDataset,
DaskArray,
CupyArray,
CupySparseMatrix,
AwkArray,
pd.DataFrame,
np.number,
str,
pd.DataFrame, np.number, str, ArrayDataStructureType
]
RWAble: TypeAlias = Union[
InMemoryArrayOrScalarType, dict[str, "RWAble"], list["RWAble"]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_creation():
AnnData(np.array([[1, 2], [3, 4]]))
AnnData(np.array([[1, 2], [3, 4]]), {}, {})
AnnData(ma.array([[1, 2], [3, 4]]), uns=dict(mask=[0, 1, 1, 0]))
AnnData(sp.eye(2))
AnnData(sp.eye(2, format="csr"))
X = np.array([[1, 2, 3], [4, 5, 6]])
adata = AnnData(
X=X,
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_creation_error(src, src_arg, dim_msg, dim, dim_arg, msg: str | None):
def test_invalid_X():
with pytest.raises(
ValueError,
match=r"X needs to be of one of np\.ndarray.*not <class 'str'>\.",
match=r"X needs to be of one of <class 'numpy.ndarray'>.*not <class 'str'>\.",
):
AnnData("string is not a valid X")

Expand Down Expand Up @@ -122,7 +122,7 @@ def test_error_create_from_multiindex_df(attr):


def test_create_from_sparse_df():
s = sp.random(20, 30, density=0.2)
s = sp.random(20, 30, density=0.2, format="csr")
obs_names = [f"obs{i}" for i in range(20)]
var_names = [f"var{i}" for i in range(30)]
df = pd.DataFrame.sparse.from_spmatrix(s, index=obs_names, columns=var_names)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_obsmvarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ def test_setting_dataframe(adata: AnnData):


def test_setting_sparse(adata: AnnData):
obsm_sparse = sparse.random(M, 100)
obsm_sparse = sparse.random(M, 100, format="csr")
adata.obsm["a"] = obsm_sparse
assert not np.any((adata.obsm["a"] != obsm_sparse).data)

varm_sparse = sparse.random(N, 100)
varm_sparse = sparse.random(N, 100, format="csr")
adata.varm["a"] = varm_sparse
assert not np.any((adata.varm["a"] != varm_sparse).data)

h = joblib.hash(adata)

bad_obsm_sparse = sparse.random(M * 2, M)
bad_obsm_sparse = sparse.random(M * 2, M, format="csr")
with pytest.raises(ValueError, match=r"incorrect shape"):
adata.obsm["b"] = bad_obsm_sparse

bad_varm_sparse = sparse.random(N * 2, N)
bad_varm_sparse = sparse.random(N * 2, N, format="csr")
with pytest.raises(ValueError, match=r"incorrect shape"):
adata.varm["b"] = bad_varm_sparse

Expand Down
8 changes: 4 additions & 4 deletions tests/test_obspvarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,21 @@ def test_setting_ndarray(adata: AnnData):


def test_setting_sparse(adata: AnnData):
obsp_sparse = sparse.random(M, M)
obsp_sparse = sparse.random(M, M, format="csr")
adata.obsp["a"] = obsp_sparse
assert not np.any((adata.obsp["a"] != obsp_sparse).data)

varp_sparse = sparse.random(N, N)
varp_sparse = sparse.random(N, N, format="csr")
adata.varp["a"] = varp_sparse
assert not np.any((adata.varp["a"] != varp_sparse).data)

h = joblib.hash(adata)

bad_obsp_sparse = sparse.random(M * 2, M)
bad_obsp_sparse = sparse.random(M * 2, M, format="csr")
with pytest.raises(ValueError, match=r"incorrect shape"):
adata.obsp["b"] = bad_obsp_sparse

bad_varp_sparse = sparse.random(N * 2, N)
bad_varp_sparse = sparse.random(N * 2, N, format="csr")
with pytest.raises(ValueError, match=r"incorrect shape"):
adata.varp["b"] = bad_varp_sparse

Expand Down
9 changes: 9 additions & 0 deletions tests/test_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,12 @@ def test_set_dense_x_view_from_sparse():
assert_equal(view.X, x1[:30])
assert_equal(orig.X[:30], x1[:30]) # change propagates through
assert_equal(orig.X[30:], x[30:]) # change propagates through


def test_warn_on_non_csr_csc_matrix():
X = sparse.eye(100)
with pytest.warns(
FutureWarning,
match=rf"AnnData previously had undefined behavior around matrices of type {type(X)}.*",
):
ad.AnnData(X=X)

0 comments on commit 2514dce

Please sign in to comment.