diff --git a/tests/core/test_anndata.py b/tests/core/test_anndata.py index f15d329a..62d28a68 100644 --- a/tests/core/test_anndata.py +++ b/tests/core/test_anndata.py @@ -10,6 +10,7 @@ from scvelo.core import ( clean_obs_names, + cleanup, get_modality, make_dense, make_sparse, @@ -63,6 +64,98 @@ def test_different_obs_id_length(self, obs_names, obs_names_cleaned): assert adata.obs["sample_batch"].str.startswith("sample").all() +class TestCleanup(TestBase): + @given(adata=get_adata(), copy=st.booleans()) + def test_cleanup_all(self, adata: AnnData, copy: bool): + returned_adata = cleanup(adata, clean="all", copy=copy) + + if copy: + assert isinstance(returned_adata, AnnData) + adata = returned_adata + else: + assert returned_adata is None + + assert len(adata.layers) == 0 + assert len(adata.uns) == 0 + assert len(adata.obs.columns) == 0 + assert len(adata.var.columns) == 0 + + @given(adata=get_adata(), copy=st.booleans()) + def test_cleanup_default_clean_w_random_adata(self, adata: AnnData, copy: bool): + n_obs_cols = len(adata.obs.columns) + n_var_cols = len(adata.var.columns) + n_uns_slots = len(adata.uns) + + returned_adata = cleanup(adata) + assert returned_adata is None + + assert len(adata.layers) == 0 + assert len(adata.uns) == n_uns_slots + assert len(adata.obs.columns) == n_obs_cols + assert len(adata.var.columns) == n_var_cols + + @given( + adata=get_adata(layer_keys=["unspliced", "spliced", "Ms", "Mu", "random"]), + copy=st.booleans(), + ) + def test_cleanup_default_clean(self, adata: AnnData, copy: bool): + n_obs_cols = len(adata.obs.columns) + n_var_cols = len(adata.var.columns) + n_uns_slots = len(adata.uns) + + returned_adata = cleanup(adata) + assert returned_adata is None + + assert len(adata.layers) == 4 + assert len(adata.uns) == n_uns_slots + assert len(adata.obs.columns) == n_obs_cols + assert len(adata.var.columns) == n_var_cols + + @given( + adata=get_adata(), + copy=st.booleans(), + n_modalities=st.integers(min_value=0), + n_cols=st.integers(min_value=0), + ) + def test_cleanup_some( + self, adata: AnnData, copy: bool, n_modalities: int, n_cols: int + ): + layers_to_keep = self._subset_modalities( + adata, + n_modalities, + from_obsm=False, + ) + obs_cols_to_keep = self._subset_columns(adata, n_cols=n_cols, from_var=False) + var_cols_to_keep = self._subset_columns(adata, n_cols=n_cols, from_obs=False) + + # Update in case adata.layers, adata.obs, adata.var share same keys + layers_to_keep += set(adata.layers).intersection(obs_cols_to_keep) + layers_to_keep += set(adata.layers).intersection(var_cols_to_keep) + + obs_cols_to_keep += set(adata.obs.columns).intersection(var_cols_to_keep) + obs_cols_to_keep += set(adata.obs.columns).intersection(layers_to_keep) + + var_cols_to_keep += set(adata.var.columns).intersection(obs_cols_to_keep) + obs_cols_to_keep += set(adata.var.columns).intersection(layers_to_keep) + + returned_adata = cleanup( + adata, + keep=layers_to_keep + obs_cols_to_keep + var_cols_to_keep, + clean="all", + copy=copy, + ) + + if copy: + assert isinstance(returned_adata, AnnData) + adata = returned_adata + else: + assert returned_adata is None + + assert set(adata.layers.keys()) == set(layers_to_keep).difference({"X"}) + assert set(adata.obs.columns) == set(obs_cols_to_keep) + assert set(adata.var.columns) == set(var_cols_to_keep) + + class TestGetModality(TestBase): @given(adata=get_adata()) def test_get_modality(self, adata: AnnData):