diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py
index 0cbdcb6cd..28bf5f093 100644
--- a/src/moscot/base/problems/_mixins.py
+++ b/src/moscot/base/problems/_mixins.py
@@ -720,3 +720,90 @@ def compute_entropy(
         if key_added is not None:
             self.adata.obs[key_added] = df
         return df if key_added is None else None
+
+    def compute_variance(
+        self: AnalysisMixinProtocol[K, B],
+        source: K,
+        target: K,
+        forward: bool = True,
+        latent_space_selection: Union[str, list[str]] = "X_pca",
+        key_added: Optional[str] = "conditional_variance",
+        batch_size: Optional[int] = None,
+    ) -> Optional[pd.DataFrame]:
+        """Compute the conditional variance per cell.
+
+        The conditional variance reflects the uncertainty of the mapping of a single cell, taking into account
+        a given latent space representation of all cells.
+
+        Parameters
+        ----------
+        source
+            Source key.
+        target
+            Target key.
+        forward
+            If :obj:`True`, compute the conditional variance given a cell in the source distribution, otherwise
+            the conditional variance given a cell in the target distribution.
+        latent_space_selection
+            Key or keys specifying the latent or feature space used to compute the conditional variance.
+            A single key has to be either a latent space in :attr:`~anndata.AnnData.obsm` or
+            a gene in :attr:`~anndata.AnnData.var_names`.
+            A list of keys has to be a subset of the genes in :attr:`~anndata.AnnData.var_names`.
+        key_added
+            Key in :attr:`~anndata.AnnData.obs` where the variance is stored.
+        batch_size
+            Batch size for the computation of the variance. If :obj:`None`, the entire dataset is used.
+
+        Returns
+        -------
+        :obj:`None` if ``key_added`` is not :obj:`None`. Otherwise, returns a data frame of shape
+        ``(n_cells, 1)`` containing the conditional variance given each cell.
+ """ + filter_value = source if forward else target + opposite_filter_value = target if forward else source + + if isinstance(latent_space_selection, str): + if latent_space_selection in self.adata.obsm: + latent_space = self.adata.obsm[latent_space_selection] + elif latent_space_selection in self.adata.var_names: + latent_space = self.adata[:, latent_space_selection in self.adata.var_names].X.toarray() + else: + raise KeyError("Gene/Latent space not found.") + elif type(latent_space_selection) in [list, np.ndarray]: + mask = [var_name in latent_space_selection for var_name in self.adata.var_names] + latent_space = self.adata[:, mask].X.toarray() + else: + raise KeyError("Unknown latent space selection.") + + latent_space_filtered = latent_space[np.array(self.adata.obs[self._policy.key] == opposite_filter_value), :] + + df = pd.DataFrame( + index=self.adata[self.adata.obs[self._policy.key] == filter_value, :].obs_names, + columns=[key_added] if key_added is not None else ["variance"], + ) + + batch_size = batch_size if batch_size is not None else len(df) + func = self.push if forward else self.pull + for batch in range(0, len(df), batch_size): + cond_dists = func( + source=source, + target=target, + data=None, + subset=(batch, batch_size), + normalize=True, + return_all=False, + scale_by_marginals=False, + split_mass=True, + key_added=None, + ) + + cond_var = [] + for i in range(cond_dists.shape[1]): # type: ignore[union-attr] + expected_val = (cond_dists[:, i]).reshape(-1, 1) * latent_space_filtered # type: ignore[index] + cond_var.append(np.linalg.norm((latent_space_filtered - expected_val), axis=1) ** 2 @ cond_dists[:, i]) # type: ignore[index] + + df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = np.array(cond_var) + + if key_added is not None: + self.adata.obs[key_added] = df + return df if key_added is None else None diff --git a/tests/problems/generic/test_mixins.py b/tests/problems/generic/test_mixins.py index 85c702c63..afab3030b 100644 --- a/tests/problems/generic/test_mixins.py +++ b/tests/problems/generic/test_mixins.py @@ -336,6 +336,45 @@ def test_compute_entropy_regression(self, adata_time: AnnData, forward: bool, ba np.array(moscot_out, dtype=float), np.array(gt_out, dtype=float), rtol=RTOL, atol=ATOL ) + @pytest.mark.parametrize("forward", [True, False]) + @pytest.mark.parametrize("key_added", [None, "test"]) + @pytest.mark.parametrize("batch_size", [None, 2]) + @pytest.mark.parametrize("latent_space_selection", ["X_pca", "KLF12", ["KLF12", "Dlip3", "Dref"]]) + def test_compute_variance_pipeline( + self, adata_time: AnnData, forward: bool, latent_space_selection, key_added: Optional[str], batch_size: int + ): + rng = np.random.RandomState(42) + adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() + n0 = adata_time[adata_time.obs["time"] == 0].n_obs + n1 = adata_time[adata_time.obs["time"] == 1].n_obs + + tmap = rng.uniform(1e-6, 1, size=(n0, n1)) + tmap /= tmap.sum().sum() + problem = CompoundProblemWithMixin(adata_time) + problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential") + problem[0, 1]._solution = MockSolverOutput(tmap) + + out = problem.compute_variance( + source=0, + target=1, + forward=forward, + key_added=key_added, + latent_space_selection=latent_space_selection, + batch_size=batch_size, + ) + if key_added is None: + assert isinstance(out, pd.DataFrame) + assert len(out) == n0 + else: + assert out is None + assert key_added in adata_time.obs + assert np.sum(adata_time[adata_time.obs["time"] == int(1 - 
forward)].obs[key_added].isna()) == 0 + assert ( + np.sum(adata_time[adata_time.obs["time"] == int(forward)].obs[key_added].isna()) == n1 + if forward + else n0 + ) + def test_seed_reproducible(self, adata_time: AnnData): key_added = "test" rng = np.random.RandomState(42)
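
Per cell, the inner loop of ``compute_variance`` amounts to a weighted mean of the latent representation followed by the expected squared deviation from that mean. The snippet below is a minimal NumPy sketch of that computation, not part of the change itself; ``p`` stands in for ``cond_dists[:, i]`` (the normalized conditional distribution over the cells of the opposite distribution) and ``Z`` for ``latent_space_filtered``, both illustrative names.

import numpy as np

def conditional_variance(p, Z):
    # p: shape (n_opposite,), non-negative and summing to 1 -- one column of cond_dists
    # Z: shape (n_opposite, n_features) -- the filtered latent representation
    mean = p @ Z  # conditional mean, shape (n_features,)
    sq_dev = np.linalg.norm(Z - mean, axis=1) ** 2  # squared distance of each cell to that mean
    return float(sq_dev @ p)  # expectation of the squared deviation

rng = np.random.default_rng(0)
p = rng.random(50)
p /= p.sum()
Z = rng.normal(size=(50, 10))
print(conditional_variance(p, Z))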
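
The ``key_added`` branch relies on pandas index alignment: ``self.adata.obs[key_added] = df`` assigns a frame indexed only by the cells of the filtered distribution, so the cells of the opposite distribution end up as NaN, which is what the pipeline test asserts. A small, self-contained illustration of that alignment (cell names and values are made up, and a Series stands in for the single-column frame for brevity):

import pandas as pd

obs = pd.DataFrame(index=["cell_0", "cell_1", "cell_2", "cell_3"])
variance = pd.Series([0.1, 0.2], index=["cell_0", "cell_1"])
obs["conditional_variance"] = variance  # aligned on index; cell_2 and cell_3 become NaN
print(obs["conditional_variance"].isna().sum())  # 2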