Skip to content

Commit

Permalink
Add TestCleanup
Browse files Browse the repository at this point in the history
Adds test class to unit test `scvelo/core/_anndata.py::cleanup`.
  • Loading branch information
WeilerP committed Jul 22, 2021
1 parent 6581307 commit 455855b
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions tests/core/test_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from scvelo.core import (
clean_obs_names,
cleanup,
get_modality,
make_dense,
make_sparse,
Expand Down Expand Up @@ -63,6 +64,98 @@ def test_different_obs_id_length(self, obs_names, obs_names_cleaned):
assert adata.obs["sample_batch"].str.startswith("sample").all()


class TestCleanup(TestBase):
@given(adata=get_adata(), copy=st.booleans())
def test_cleanup_all(self, adata: AnnData, copy: bool):
returned_adata = cleanup(adata, clean="all", copy=copy)

if copy:
assert isinstance(returned_adata, AnnData)
adata = returned_adata
else:
assert returned_adata is None

assert len(adata.layers) == 0
assert len(adata.uns) == 0
assert len(adata.obs.columns) == 0
assert len(adata.var.columns) == 0

@given(adata=get_adata(), copy=st.booleans())
def test_cleanup_default_clean_w_random_adata(self, adata: AnnData, copy: bool):
n_obs_cols = len(adata.obs.columns)
n_var_cols = len(adata.var.columns)
n_uns_slots = len(adata.uns)

returned_adata = cleanup(adata)
assert returned_adata is None

assert len(adata.layers) == 0
assert len(adata.uns) == n_uns_slots
assert len(adata.obs.columns) == n_obs_cols
assert len(adata.var.columns) == n_var_cols

@given(
adata=get_adata(layer_keys=["unspliced", "spliced", "Ms", "Mu", "random"]),
copy=st.booleans(),
)
def test_cleanup_default_clean(self, adata: AnnData, copy: bool):
n_obs_cols = len(adata.obs.columns)
n_var_cols = len(adata.var.columns)
n_uns_slots = len(adata.uns)

returned_adata = cleanup(adata)
assert returned_adata is None

assert len(adata.layers) == 4
assert len(adata.uns) == n_uns_slots
assert len(adata.obs.columns) == n_obs_cols
assert len(adata.var.columns) == n_var_cols

@given(
adata=get_adata(),
copy=st.booleans(),
n_modalities=st.integers(min_value=0),
n_cols=st.integers(min_value=0),
)
def test_cleanup_some(
self, adata: AnnData, copy: bool, n_modalities: int, n_cols: int
):
layers_to_keep = self._subset_modalities(
adata,
n_modalities,
from_obsm=False,
)
obs_cols_to_keep = self._subset_columns(adata, n_cols=n_cols, from_var=False)
var_cols_to_keep = self._subset_columns(adata, n_cols=n_cols, from_obs=False)

# Update in case adata.layers, adata.obs, adata.var share same keys
layers_to_keep += set(adata.layers).intersection(obs_cols_to_keep)
layers_to_keep += set(adata.layers).intersection(var_cols_to_keep)

obs_cols_to_keep += set(adata.obs.columns).intersection(var_cols_to_keep)
obs_cols_to_keep += set(adata.obs.columns).intersection(layers_to_keep)

var_cols_to_keep += set(adata.var.columns).intersection(obs_cols_to_keep)
obs_cols_to_keep += set(adata.var.columns).intersection(layers_to_keep)

returned_adata = cleanup(
adata,
keep=layers_to_keep + obs_cols_to_keep + var_cols_to_keep,
clean="all",
copy=copy,
)

if copy:
assert isinstance(returned_adata, AnnData)
adata = returned_adata
else:
assert returned_adata is None

assert set(adata.layers.keys()) == set(layers_to_keep).difference({"X"})
assert set(adata.obs.columns) == set(obs_cols_to_keep)
assert set(adata.var.columns) == set(var_cols_to_keep)


class TestGetModality(TestBase):
@given(adata=get_adata())
def test_get_modality(self, adata: AnnData):
Expand Down

0 comments on commit 455855b

Please sign in to comment.