-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add conditional variance #712
base: main
Are you sure you want to change the base?
Changes from all commits
1710e4e
5baf9f6
fa72e23
5a392f5
3aebd31
b74fe9e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -720,3 +720,90 @@ def compute_entropy( | |
if key_added is not None: | ||
self.adata.obs[key_added] = df | ||
return df if key_added is None else None | ||
|
||
def compute_variance(
    self: AnalysisMixinProtocol[K, B],
    source: K,
    target: K,
    forward: bool = True,
    latent_space_selection: Union[str, list[str]] = "X_pca",
    key_added: Optional[str] = "conditional_variance",
    batch_size: Optional[int] = None,
) -> Optional[pd.DataFrame]:
    """Compute the conditional variance per cell.

    The conditional variance reflects the uncertainty of the mapping of a single cell by taking into account
    a given representation (latent space or raw feature space) of all cells.

    Parameters
    ----------
    source
        Source key.
    target
        Target key.
    forward
        If `True`, computes the conditional variance given a cell in the source distribution, else the
        conditional variance given a cell in the target distribution.
    latent_space_selection
        Key or list of keys specifying the latent or feature space used for computing the conditional variance.
        A single key has to be a representation in :attr:`~anndata.AnnData.obsm` or
        a feature in :attr:`~anndata.AnnData.var_names`.
        A list of keys has to contain features in :attr:`~anndata.AnnData.var_names`; entries that are not
        present are ignored, but at least one of them must exist.
    key_added
        Key in :attr:`~anndata.AnnData.obs` where the variance is stored.
    batch_size
        Number of cells pushed/pulled at once. If :obj:`None`, all cells are processed in a single batch.
        Batching only bounds the size of the intermediate conditional distributions.

    Returns
    -------
    :obj:`None` if ``key_added`` is not None. Otherwise, returns a data frame of shape ``(n_cells, 1)`` containing
    the conditional variance given each cell.

    Raises
    ------
    KeyError
        If the requested representation/feature(s) cannot be found or the selection has an unsupported type.
    ValueError
        If ``batch_size`` is not a positive integer.
    """

    def _to_dense(matrix):
        # ``X``/``obsm`` entries may be sparse or dense; only sparse containers expose ``toarray``.
        return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    def _select_representation(selection):
        """Return the dense ``(n_obs, n_features)`` matrix described by ``selection``."""
        if isinstance(selection, str):
            if selection in self.adata.obsm:
                return _to_dense(self.adata.obsm[selection])
            if selection in self.adata.var_names:
                # Index with a list so that the result stays two-dimensional.
                return _to_dense(self.adata[:, [selection]].X)
            raise KeyError(f"Unable to find `{selection}` in `adata.obsm` or `adata.var_names`.")
        if isinstance(selection, (list, tuple, np.ndarray)):
            feature_mask = np.asarray(self.adata.var_names.isin(list(selection)))
            if not feature_mask.any():
                raise KeyError(f"None of the features `{list(selection)}` were found in `adata.var_names`.")
            return _to_dense(self.adata[:, feature_mask].X)
        raise KeyError(
            f"Unknown latent space selection `{selection!r}` of type `{type(selection).__name__}`."
        )

    filter_value = source if forward else target
    opposite_filter_value = target if forward else source

    groups = self.adata.obs[self._policy.key]
    conditioned_mask = np.asarray(groups == filter_value)
    opposite_mask = np.asarray(groups == opposite_filter_value)

    # Representation of the cells that receive the mass of each conditioned cell.
    features = _select_representation(latent_space_selection)[opposite_mask, :]
    squared_norms = np.sum(np.asarray(features, dtype=float) ** 2, axis=1)

    df = pd.DataFrame(
        index=self.adata.obs_names[conditioned_mask],
        columns=[key_added] if key_added is not None else ["variance"],
        dtype=float,
    )

    n_cells = len(df)
    # ``max`` keeps the step non-zero when there are no conditioned cells at all.
    step = batch_size if batch_size is not None else max(n_cells, 1)
    if step <= 0:
        raise ValueError(f"Expected `batch_size` to be a positive integer, found `{batch_size}`.")

    func = self.push if forward else self.pull
    for start in range(0, n_cells, step):
        # One column per conditioned cell in the batch, one row per cell of the opposite distribution.
        cond_dists = np.asarray(
            func(
                source=source,
                target=target,
                data=None,
                subset=(start, step),
                normalize=True,
                return_all=False,
                scale_by_marginals=False,
                split_mass=True,
                key_added=None,
            ),
            dtype=float,
        )

        # Vectorised form of, for every column ``w`` of ``cond_dists``:
        #   ||x_j - w_j * x_j||^2 @ w  ==  sum_j w_j * (1 - w_j)^2 * ||x_j||^2
        # NOTE(review): this reproduces the original quantity, which scales each row by its own
        # weight rather than subtracting the conditional mean ``sum_j w_j * x_j``; confirm that
        # this is the intended definition of the conditional variance.
        cond_var = (cond_dists * (1.0 - cond_dists) ** 2).T @ squared_norms

        df.iloc[start : min(start + step, n_cells), 0] = cond_var

    if key_added is None:
        return df
    # Assignment aligns on ``obs_names``; cells outside the conditioned group become NaN.
    self.adata.obs[key_added] = df[key_added]
    return None
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -336,6 +336,45 @@ def test_compute_entropy_regression(self, adata_time: AnnData, forward: bool, ba | |
np.array(moscot_out, dtype=float), np.array(gt_out, dtype=float), rtol=RTOL, atol=ATOL | ||
) | ||
|
||
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("key_added", [None, "test"])
@pytest.mark.parametrize("batch_size", [None, 2])
@pytest.mark.parametrize("latent_space_selection", ["X_pca", "KLF12", ["KLF12", "Dlip3", "Dref"]])
def test_compute_variance_pipeline(
    self,
    adata_time: AnnData,
    forward: bool,
    latent_space_selection,
    key_added: Optional[str],
    batch_size: Optional[int],
):
    """Check the output container of ``compute_variance`` for every parameter combination.

    With ``forward=True`` the variance is computed for the cells at time 0 (conditioning on the source),
    otherwise for the cells at time 1; the cells of the other time point must stay NaN in ``obs``.
    """
    rng = np.random.RandomState(42)
    adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
    n0 = adata_time[adata_time.obs["time"] == 0].n_obs
    n1 = adata_time[adata_time.obs["time"] == 1].n_obs

    # Cells the variance is computed for vs. cells of the opposite distribution.
    conditioned_time, other_time = (0, 1) if forward else (1, 0)
    n_conditioned, n_other = (n0, n1) if forward else (n1, n0)

    tmap = rng.uniform(1e-6, 1, size=(n0, n1))
    tmap /= tmap.sum()
    problem = CompoundProblemWithMixin(adata_time)
    problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
    problem[0, 1]._solution = MockSolverOutput(tmap)

    out = problem.compute_variance(
        source=0,
        target=1,
        forward=forward,
        key_added=key_added,
        latent_space_selection=latent_space_selection,
        batch_size=batch_size,
    )

    if key_added is None:
        assert isinstance(out, pd.DataFrame)
        assert out.shape == (n_conditioned, 1)
        assert not out.isna().to_numpy().any()
    else:
        assert out is None
        assert key_added in adata_time.obs
        values = adata_time.obs[key_added]
        times = adata_time.obs["time"]
        assert values[np.asarray(times == conditioned_time)].isna().sum() == 0
        # Parenthesised on purpose: ``a == b if c else d`` parses as ``(a == b) if c else d``.
        assert values[np.asarray(times == other_time)].isna().sum() == n_other
|
||
def test_seed_reproducible(self, adata_time: AnnData): | ||
key_added = "test" | ||
rng = np.random.RandomState(42) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's not call it a latent space; it can also be a raw space, e.g. gene space. Also, the accepted types are not clear.