Skip to content

Commit

Permalink
(fix): allow init without X and correct shape inferred (#1841)
Browse files Browse the repository at this point in the history
* (fix): allow init without `X` and correct shape inferred

* (chore): relnote

* (fix): typing

* (chore): more robust size getting

* (fix): `layers` bug

* (refactor): move helpers to bottom
  • Loading branch information
ilan-gold authored Feb 13, 2025
1 parent da04841 commit 4dc0618
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 44 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1941.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow initialization of {class}`anndata.AnnData` objects without `X` (since they could be constructed previously by deleting `X`) {user}`ilan-gold`
81 changes: 52 additions & 29 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import h5py
import numpy as np
Expand Down Expand Up @@ -60,33 +60,6 @@
from .index import Index


# for backwards compat
def _find_corresponding_multicol_key(key, keys_multicol):
"""Find the corresponding multicolumn key."""
for mk in keys_multicol:
if key.startswith(mk) and "of" in key:
return mk
return None


# for backwards compat
def _gen_keys_from_multicol_key(key_multicol, n_keys):
"""Generates single-column keys from multicolumn key."""
keys = [f"{key_multicol}{i + 1:03}of{n_keys:03}" for i in range(n_keys)]
return keys


def _check_2d_shape(X):
"""\
Check shape of array or sparse matrix.
Assure that X is always 2D: Unlike numpy we always deal with 2D arrays.
"""
if X.dtype.names is None and len(X.shape) != 2:
msg = f"X needs to be 2-dimensional, not {len(X.shape)}-dimensional."
raise ValueError(msg)


class AnnData(metaclass=utils.DeprecationMixinMeta):
"""\
An annotated data matrix.
Expand Down Expand Up @@ -433,7 +406,11 @@ def _init_as_actual(
source = "X"
else:
self._X = None
n_obs, n_vars = (None, None) if shape is None else shape
n_obs, n_vars = (
shape
if shape is not None
else _infer_shape(obs, var, obsm, varm, layers, obsp, varp)
)
source = "shape"

# annotations
Expand Down Expand Up @@ -2081,3 +2058,49 @@ def _get_and_delete_multicol_field(self, a, key_multicol):
values = getattr(self, a)[keys].values
getattr(self, a).drop(keys, axis=1, inplace=True)
return values


def _check_2d_shape(X):
"""\
Check shape of array or sparse matrix.
Assure that X is always 2D: Unlike numpy we always deal with 2D arrays.
"""
if X.dtype.names is None and len(X.shape) != 2:
msg = f"X needs to be 2-dimensional, not {len(X.shape)}-dimensional."
raise ValueError(msg)


def _infer_shape_for_axis(
xxx: pd.DataFrame | Mapping[str, Iterable[Any]] | None,
xxxm: np.ndarray | Mapping[str, Sequence[Any]] | None,
layers: Mapping[str, np.ndarray | sparse.spmatrix] | None,
xxxp: np.ndarray | Mapping[str, Sequence[Any]] | None,
axis: Literal[0, 1],
) -> int | None:
for elem in [xxx, xxxm, xxxp]:
if elem is not None and hasattr(elem, "shape"):
return elem.shape[0]
for elem, id in zip([layers, xxxm, xxxp], ["layers", "xxxm", "xxxp"]):
if elem is not None:
elem = cast(Mapping, elem)
for sub_elem in elem.values():
if hasattr(sub_elem, "shape"):
size = cast(int, sub_elem.shape[axis if id == "layers" else 0])
return size
return None


def _infer_shape(
obs: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
var: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
obsm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
varm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
layers: Mapping[str, np.ndarray | sparse.spmatrix] | None = None,
obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
):
return (
_infer_shape_for_axis(obs, obsm, layers, obsp, 0),
_infer_shape_for_axis(var, varm, layers, varp, 1),
)
21 changes: 21 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import warnings
from itertools import product
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
Expand All @@ -11,11 +12,16 @@
from scipy import sparse as sp
from scipy.sparse import csr_matrix, issparse

import anndata as ad
from anndata import AnnData, ImplicitModificationWarning
from anndata._settings import settings
from anndata.compat import CAN_USE_SPARSE_ARRAY
from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df

if TYPE_CHECKING:
from pathlib import Path
from typing import Literal

# some test objects that we use below
adata_dense = AnnData(np.array([[1, 2], [3, 4]]))
adata_dense.layers["test"] = adata_dense.X
Expand Down Expand Up @@ -735,3 +741,18 @@ def test_to_memory_no_copy():
assert mem.obsp[key] is adata.obsp[key]
for key in adata.varp:
assert mem.varp[key] is adata.varp[key]


@pytest.mark.parametrize("axis", ["obs", "var"])
@pytest.mark.parametrize("elem_type", ["p", "m"])
def test_create_adata_from_single_axis_elem(
axis: Literal["obs", "var"], elem_type: Literal["m", "p"], tmp_path: Path
):
d = dict(
a=np.zeros((10, 10)),
)
in_memory = AnnData(**{f"{axis}{elem_type}": d})
assert in_memory.shape == (10, 0) if axis == "obs" else (0, 10)
in_memory.write_h5ad(tmp_path / "adata.h5ad")
from_disk = ad.read_h5ad(tmp_path / "adata.h5ad")
ad.tests.helpers.assert_equal(from_disk, in_memory)
36 changes: 21 additions & 15 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,28 @@
from anndata.io import read_loom
from anndata.tests.helpers import gen_typed_df_t2_size

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
L = np.array([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
X_ = np.arange(12).reshape((3, 4))
L = np.arange(12).reshape((3, 4)) + 12


def test_creation():
@pytest.fixture(params=[X_, None])
def X(request):
return request.param


def test_creation(X: np.ndarray | None):
adata = AnnData(X=X, layers=dict(L=L.copy()))

assert list(adata.layers.keys()) == ["L"]
assert "L" in adata.layers
assert "X" not in adata.layers
assert "some_other_thing" not in adata.layers
assert (adata.layers["L"] == L).all()
assert adata.shape == L.shape


def test_views():
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata = AnnData(X=X_, layers=dict(L=L.copy()))
adata_view = adata[1:, 1:]

assert adata_view.layers.is_view
Expand All @@ -36,13 +42,13 @@ def test_views():
assert adata_view.layers.keys() == adata.layers.keys()
assert (adata_view.layers["L"] == adata.layers["L"][1:, 1:]).all()

adata.layers["S"] = X
adata.layers["S"] = X_

assert adata_view.layers.keys() == adata.layers.keys()
assert (adata_view.layers["S"] == adata.layers["S"][1:, 1:]).all()

with pytest.warns(ImplicitModificationWarning):
adata_view.layers["T"] = X[1:, 1:]
adata_view.layers["T"] = X_[1:, 1:]

assert not adata_view.layers.is_view
assert not adata_view.is_view
Expand All @@ -51,12 +57,12 @@ def test_views():
@pytest.mark.parametrize(
("df", "homogenous", "dtype"),
[
(lambda: gen_typed_df_t2_size(*X.shape), True, np.object_),
(lambda: pd.DataFrame(X**2), False, np.int_),
(lambda: gen_typed_df_t2_size(*X_.shape), True, np.object_),
(lambda: pd.DataFrame(X_**2), False, np.int_),
],
)
def test_set_dataframe(homogenous, df, dtype):
adata = AnnData(X)
adata = AnnData(X_)
if homogenous:
with pytest.warns(UserWarning, match=r"Layer 'df'.*dtype object"):
adata.layers["df"] = df()
Expand All @@ -68,7 +74,7 @@ def test_set_dataframe(homogenous, df, dtype):
assert np.issubdtype(adata.layers["df"].dtype, dtype)


def test_readwrite(backing_h5ad):
def test_readwrite(X: np.ndarray | None, backing_h5ad):
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata.write(backing_h5ad)
adata_read = read_h5ad(backing_h5ad)
Expand All @@ -80,7 +86,7 @@ def test_readwrite(backing_h5ad):
@pytest.mark.skipif(find_spec("loompy") is None, reason="loompy not installed")
def test_readwrite_loom(tmp_path):
loom_path = tmp_path / "test.loom"
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata = AnnData(X=X_, layers=dict(L=L.copy()))

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
Expand All @@ -104,7 +110,7 @@ def test_backed():


def test_copy():
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata = AnnData(X=X_, layers=dict(L=L.copy()))
bdata = adata.copy()
# check that we don’t create too many references
assert bdata._layers is bdata.layers._data
Expand All @@ -114,13 +120,13 @@ def test_copy():


def test_shape_error():
adata = AnnData(X=X)
adata = AnnData(X=X_)
with pytest.raises(
ValueError,
match=(
r"Value passed for key 'L' is of incorrect shape\. "
r"Values of layers must match dimensions \('obs', 'var'\) of parent\. "
r"Value had shape \(4, 3\) while it should have had \(3, 3\)\."
r"Value had shape \(4, 4\) while it should have had \(3, 4\)\."
),
):
adata.layers["L"] = np.zeros((X.shape[0] + 1, X.shape[1]))
adata.layers["L"] = np.zeros((X_.shape[0] + 1, X_.shape[1]))

0 comments on commit 4dc0618

Please sign in to comment.