Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conditional variance #712

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,90 @@ def compute_entropy(
if key_added is not None:
self.adata.obs[key_added] = df
return df if key_added is None else None

def compute_variance(
self: AnalysisMixinProtocol[K, B],
source: K,
target: K,
forward: bool = True,
latent_space_selection: Union[str, list[str]] = "X_pca",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not call it latent space. can also be raw space, e.g. gene space. Also, the types are not clear

key_added: Optional[str] = "conditional_variance",
batch_size: Optional[int] = None,
) -> Optional[pd.DataFrame]:
"""Compute the conditional variance per cell.

The conditional variance reflects the uncertainty of the mapping of a single cell by taking into account
a given latent space representation of all cells.

Parameters
----------
source
Source key.
target
Target key.
forward
If `True`, computes the conditional variance given a cell in the source distribution, else the
conditional variance given a cell in the target distribution.
latent_space_selection:
Key or Keys which specifies the latent or feature space used for computing the conditional variance.
A single key has to be a latent space in :attr:`~anndata.AnnData.obsm` or
a gene in :attr:`~anndata.AnnData.var_names`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a feature in ..., because we might also store proteins/ATAC, etc.

A set of keys has to be a subset of genes in :attr:`~anndata.AnnData.var_names`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type hinting doesn't say set, but list.

key_added
Key in :attr:`~anndata.AnnData.obs` where the variance is stored.
batch_size
Batch size for the computation of the variance. If :obj:`None`, the entire dataset is used.

Returns
-------
:obj:`None` if ``key_added`` is not None. Otherwise, returns a data frame of shape ``(n_cells, 1)`` containing
the conditional variance given each cell.
"""
filter_value = source if forward else target
opposite_filter_value = target if forward else source

if isinstance(latent_space_selection, str):
if latent_space_selection in self.adata.obsm:
latent_space = self.adata.obsm[latent_space_selection]
elif latent_space_selection in self.adata.var_names:
latent_space = self.adata[:, latent_space_selection in self.adata.var_names].X.toarray()
else:
raise KeyError("Gene/Latent space not found.")
elif type(latent_space_selection) in [list, np.ndarray]:
mask = [var_name in latent_space_selection for var_name in self.adata.var_names]
latent_space = self.adata[:, mask].X.toarray()
else:
raise KeyError("Unknown latent space selection.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we have a key error, we want to print what the wrong key is.

Comment on lines +762 to +776
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this a function (within the function)


latent_space_filtered = latent_space[np.array(self.adata.obs[self._policy.key] == opposite_filter_value), :]

df = pd.DataFrame(
index=self.adata[self.adata.obs[self._policy.key] == filter_value, :].obs_names,
columns=[key_added] if key_added is not None else ["variance"],
)

batch_size = batch_size if batch_size is not None else len(df)
func = self.push if forward else self.pull
for batch in range(0, len(df), batch_size):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we actually do this ? :)

cond_dists = func(
source=source,
target=target,
data=None,
subset=(batch, batch_size),
normalize=True,
return_all=False,
scale_by_marginals=False,
split_mass=True,
key_added=None,
)

cond_var = []
for i in range(cond_dists.shape[1]): # type: ignore[union-attr]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we vectorize this?

expected_val = (cond_dists[:, i]).reshape(-1, 1) * latent_space_filtered # type: ignore[index]
cond_var.append(np.linalg.norm((latent_space_filtered - expected_val), axis=1) ** 2 @ cond_dists[:, i]) # type: ignore[index]

df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = np.array(cond_var)

if key_added is not None:
self.adata.obs[key_added] = df
return df if key_added is None else None
39 changes: 39 additions & 0 deletions tests/problems/generic/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,45 @@ def test_compute_entropy_regression(self, adata_time: AnnData, forward: bool, ba
np.array(moscot_out, dtype=float), np.array(gt_out, dtype=float), rtol=RTOL, atol=ATOL
)

@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("key_added", [None, "test"])
@pytest.mark.parametrize("batch_size", [None, 2])
@pytest.mark.parametrize("latent_space_selection", ["X_pca", "KLF12", ["KLF12", "Dlip3", "Dref"]])
def test_compute_variance_pipeline(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also check for raise Error with wrong attributes.

self, adata_time: AnnData, forward: bool, latent_space_selection, key_added: Optional[str], batch_size: int
):
rng = np.random.RandomState(42)
adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
n0 = adata_time[adata_time.obs["time"] == 0].n_obs
n1 = adata_time[adata_time.obs["time"] == 1].n_obs

tmap = rng.uniform(1e-6, 1, size=(n0, n1))
tmap /= tmap.sum().sum()
problem = CompoundProblemWithMixin(adata_time)
problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
problem[0, 1]._solution = MockSolverOutput(tmap)

out = problem.compute_variance(
source=0,
target=1,
forward=forward,
key_added=key_added,
latent_space_selection=latent_space_selection,
batch_size=batch_size,
)
if key_added is None:
assert isinstance(out, pd.DataFrame)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check for some properties, e.g. no NaN, non-negativity

assert len(out) == n0
else:
assert out is None
assert key_added in adata_time.obs
assert np.sum(adata_time[adata_time.obs["time"] == int(1 - forward)].obs[key_added].isna()) == 0
assert (
np.sum(adata_time[adata_time.obs["time"] == int(forward)].obs[key_added].isna()) == n1
if forward
else n0
)

def test_seed_reproducible(self, adata_time: AnnData):
key_added = "test"
rng = np.random.RandomState(42)
Expand Down
Loading