fix: MrVI MixtureSameFamily log probability fixed (#3189)
Fixes #3188.

This PR fixes the issue in both the `get_aggregated_posterior` function
as well as the Mixture of Gaussians prior.

Tested on the tutorial: integration remains the same, and differential abundance
results now span a much larger range, as expected (see the issue). To mitigate
spikiness in favor of the covariate value of each cell's sample of origin, we
added the default `omit_original_sample=True`, which keeps the sample of
origin's log probability out of the calculation.
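
At its core, the bug was a shape problem: the aggregated posterior was built from transposed `(n_latent, n_cells)` parameters, so `MixtureSameFamily` treated each latent dimension as a separate univariate mixture, and `log_prob(u).sum(-1)` summed per-dimension mixture log probabilities instead of evaluating `u` under a mixture of multivariate Gaussians. A minimal numpyro sketch of the difference, with toy shapes rather than the model's tensors:

```python
import jax.numpy as jnp
from numpyro.distributions import Categorical, MixtureSameFamily, Normal

n_cells, n_latent = 5, 3  # toy sizes
locs = jnp.arange(n_cells * n_latent, dtype=jnp.float32).reshape(n_cells, n_latent)
scales = jnp.ones((n_cells, n_latent))
u = jnp.zeros((2, n_latent))  # two query points

# Old construction: transposed params yield n_latent independent
# univariate mixtures; summing the last axis multiplies per-dimension
# mixture densities instead of mixing joint Gaussians.
buggy = MixtureSameFamily(
    Categorical(probs=jnp.ones(n_cells) / n_cells),
    Normal(locs.T, scales.T),
)
buggy_lp = buggy.log_prob(u).sum(-1)  # (2,)

# Fixed construction: to_event(1) makes each component a diagonal
# multivariate Gaussian, so the mixture is over whole cells.
fixed = MixtureSameFamily(
    Categorical(probs=jnp.ones(n_cells) / n_cells),
    Normal(locs, scales).to_event(1),
)
fixed_lp = fixed.log_prob(u)  # (2,)

print(buggy_lp, fixed_lp)  # the two generally disagree
```

The `to_event(1)` call is what turns each per-cell `Normal` into a diagonal multivariate Gaussian, so the mixture is taken over whole cells rather than over scalar components.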

CC @rastogiruchir

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
justjhong and pre-commit-ci[bot] authored Feb 23, 2025
1 parent a9e45bf commit 1a026f8
Showing 3 changed files with 61 additions and 26 deletions.
65 changes: 48 additions & 17 deletions src/scvi/external/mrvi/_model.py
@@ -783,7 +783,11 @@ def get_aggregated_posterior(
         -------
         A mixture distribution of the aggregated posterior.
         """
-        from numpyro.distributions import Categorical, MixtureSameFamily, Normal
+        from numpyro.distributions import (
+            Categorical,
+            MixtureSameFamily,
+            Normal,
+        )

         self._check_if_trained(warn=False)
         adata = self._validate_anndata(adata)
@@ -806,11 +810,11 @@ def get_aggregated_posterior(
             qu_locs.append(outputs["qu"].loc)
             qu_scales.append(outputs["qu"].scale)

-        qu_loc = jnp.concatenate(qu_locs, axis=0).T
-        qu_scale = jnp.concatenate(qu_scales, axis=0).T
+        qu_loc = jnp.concatenate(qu_locs, axis=0)  # n_cells x n_latent_u
+        qu_scale = jnp.concatenate(qu_scales, axis=0)  # n_cells x n_latent_u
         return MixtureSameFamily(
-            Categorical(probs=jnp.ones(qu_loc.shape[1]) / qu_loc.shape[1]),
-            Normal(qu_loc, qu_scale),
+            Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]),
+            Normal(qu_loc, qu_scale).to_event(1),
         )

     def differential_abundance(
@@ -819,6 +823,7 @@ def differential_abundance(
         sample_cov_keys: list[str] | None = None,
         sample_subset: list[str] | None = None,
         compute_log_enrichment: bool = False,
+        omit_original_sample: bool = True,
         batch_size: int = 128,
     ) -> xr.Dataset:
         """Compute the differential abundance between samples.
@@ -839,6 +844,9 @@ def differential_abundance(
             Only compute differential abundance for these sample labels.
         compute_log_enrichment
             Whether to compute the log enrichment scores for each covariate value.
+        omit_original_sample
+            If True, exclude each cell's sample of origin from the aggregated posteriors
+            used to compute its log probabilities. Only relevant if ``sample_cov_keys``
+            is not ``None``.
         batch_size
             Minibatch size for computing the differential abundance.
@@ -883,7 +891,7 @@ def differential_abundance(
             n_splits = max(adata.n_obs // batch_size, 1)
             log_probs_ = []
             for u_rep in np.array_split(us, n_splits):
-                log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True)))
+                log_probs_.append(jax.device_get(ap.log_prob(u_rep))[..., np.newaxis])

             log_probs.append(np.concatenate(log_probs_, axis=0))  # (n_cells, 1)

@@ -901,6 +909,23 @@ def differential_abundance(
         if sample_cov_keys is None or len(sample_cov_keys) == 0:
             return log_probs_arr

+        def aggregate_log_probs(log_probs, samples, omit_original_sample=False):
+            sample_log_probs = log_probs.loc[
+                {"sample": samples}
+            ].values  # (n_cells, n_samples_in_group)
+            if omit_original_sample:
+                sample_one_hot = np.zeros((adata.n_obs, len(samples)))
+                for i, sample in enumerate(samples):
+                    sample_one_hot[adata.obs[self.sample_key] == sample, i] = 1
+                log_probs_no_original = np.where(
+                    sample_one_hot, -np.inf, sample_log_probs
+                )  # virtually discards samples-of-origin from aggregate posteriors
+                return logsumexp(log_probs_no_original, axis=1) - np.log(
+                    (1 - sample_one_hot).sum(axis=1)
+                )
+            else:
+                return logsumexp(sample_log_probs, axis=1) - np.log(sample_log_probs.shape[1])
+
         sample_cov_log_probs_map = {}
         sample_cov_log_enrichs_map = {}
         for sample_cov_key in sample_cov_keys:
@@ -916,8 +941,11 @@ def differential_abundance(
                 if len(cov_samples) == 0:
                     continue

-                sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}]
-                val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(sel_log_probs.shape[1])
+                val_log_probs = aggregate_log_probs(
+                    log_probs_arr.log_probs,
+                    cov_samples,
+                    omit_original_sample=omit_original_sample,
+                )
                 per_val_log_probs[sample_cov_value] = val_log_probs

                 if compute_log_enrichment:
@@ -930,9 +958,10 @@ def differential_abundance(
                             stacklevel=2,
                         )
                         continue
-                    rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}]
-                    rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log(
-                        rest_log_probs.shape[1]
+                    rest_val_log_probs = aggregate_log_probs(
+                        log_probs_arr.log_probs,
+                        rest_samples,
+                        omit_original_sample=omit_original_sample,
                     )
                     enrichment_scores = val_log_probs - rest_val_log_probs
                     per_val_log_enrichs[sample_cov_value] = enrichment_scores
@@ -1018,8 +1047,8 @@ def get_outlier_cell_sample_pairs(

             ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs)
             in_max_comp_log_probs = ap.component_distribution.log_prob(
-                np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)
-            ).sum(axis=1)
+                np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)  # (n_cells_ap, 1, n_latent_dim)
+            )  # (n_cells_ap, n_cells_ap)
             log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs)

             log_probs_ = []
@@ -1028,10 +1057,12 @@ def get_outlier_cell_sample_pairs(
                 log_probs_.append(
                     jax.device_get(
                         ap.component_distribution.log_prob(
-                            np.expand_dims(u_rep, ap.mixture_dim)
-                        )  # (n_cells_batch, n_cells_ap, n_latent_dim)
-                        .sum(axis=1)  # (n_cells_batch, n_latent_dim)
-                        .max(axis=1, keepdims=True)  # (n_cells_batch, 1)
+                            np.expand_dims(
+                                u_rep, ap.mixture_dim
+                            )  # (n_cells_batch, 1, n_latent_dim)
+                        ).max(  # (n_cells_batch, n_cells_ap)
+                            axis=1, keepdims=True
+                        )  # (n_cells_batch, 1)
                     )
                 )

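The new `aggregate_log_probs` helper above averages per-sample aggregated-posterior densities in log space; with `omit_original_sample=True` it first masks each cell's own sample with `-inf` so the average runs over the remaining samples only. A standalone sketch of that masking step, on made-up toy inputs rather than the model's outputs:

```python
import numpy as np
from scipy.special import logsumexp

# Toy inputs: log p(u_cell | aggregated posterior of sample s)
# for 4 cells and 3 samples, plus hypothetical sample-of-origin indices.
log_probs = np.random.default_rng(0).normal(size=(4, 3))
origin = np.array([0, 1, 1, 2])

one_hot = np.eye(3)[origin]  # (n_cells, n_samples)
# -inf removes a cell's own sample from the log-sum-exp, and the
# divisor shrinks accordingly, leaving a mean over the other samples.
masked = np.where(one_hot.astype(bool), -np.inf, log_probs)
avg_log_prob = logsumexp(masked, axis=1) - np.log((1 - one_hot).sum(axis=1))
print(avg_log_prob)  # (n_cells,)
```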
21 changes: 12 additions & 9 deletions src/scvi/external/mrvi/_module.py
@@ -268,7 +268,10 @@ def __call__(
         sample_covariate: jax.typing.ArrayLike,
         training: bool | None = None,
     ) -> dist.Normal:
-        from scvi.external.mrvi._components import ConditionalNormalization, NormalDistOutputNN
+        from scvi.external.mrvi._components import (
+            ConditionalNormalization,
+            NormalDistOutputNN,
+        )

         training = nn.merge_param("training", self.training, training)
         x_feat = jnp.log1p(x)
@@ -424,10 +427,10 @@ def setup(self):
             "u_prior_logits", nn.initializers.zeros, (u_prior_mixture_k,)
         )
         self.u_prior_means = self.param(
-            "u_prior_means", jax.random.normal, (u_dim, u_prior_mixture_k)
+            "u_prior_means", jax.random.normal, (u_prior_mixture_k, u_dim)
         )
         self.u_prior_scales = self.param(
-            "u_prior_scales", nn.initializers.zeros, (u_dim, u_prior_mixture_k)
+            "u_prior_scales", nn.initializers.zeros, (u_prior_mixture_k, u_dim)
         )

     @property
@@ -517,7 +520,9 @@ def generative(
                 10.0 * jax.nn.one_hot(label_index, self.n_labels) if self.n_labels >= 2 else 0.0
             )
             cats = dist.Categorical(logits=self.u_prior_logits + offset)
-            normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales))
+            normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales)).to_event(
+                1
+            )
             pu = dist.MixtureSameFamily(cats, normal_dists)
         else:
             pu = dist.Normal(0, jnp.exp(self.u_prior_scale))
@@ -536,13 +541,11 @@ def loss(
         )

         if self.u_prior_mixture:
-            kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]) - generative_outputs[
-                "pu"
-            ].log_prob(inference_outputs["u"])
-            kl_u = kl_u.sum(-1)
+            kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]).sum(
+                -1
+            ) - generative_outputs["pu"].log_prob(inference_outputs["u"])
         else:
             kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]).sum(-1)
-            inference_outputs["qeps"]

         kl_z = 0.0
         eps = inference_outputs["z"] - inference_outputs["z_base"]
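Because there is no closed-form KL between the Gaussian posterior and the mixture-of-Gaussians prior, the module estimates `kl_u` from a single Monte Carlo sample; the fix sums the per-dimension posterior log probabilities before subtracting the prior's log probability, which after `to_event(1)` is already a joint density over the full latent vector. A small sketch of that bookkeeping, with made-up sizes:

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

k, d = 3, 4  # hypothetical number of mixture components and latent dims
key_means, key_u = jax.random.split(jax.random.PRNGKey(0))

qu = dist.Normal(jnp.zeros(d), jnp.ones(d))  # factorized posterior over u
pu = dist.MixtureSameFamily(
    dist.Categorical(logits=jnp.zeros(k)),
    dist.Normal(jax.random.normal(key_means, (k, d)), jnp.ones((k, d))).to_event(1),
)

u = qu.sample(key_u)  # (d,)
# Single-sample Monte Carlo estimate of KL(q || p): both terms must be
# joint log densities over u, hence the .sum(-1) on the factorized q.
kl_u = qu.log_prob(u).sum(-1) - pu.log_prob(u)
print(kl_u)  # scalar
```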
1 change: 1 addition & 0 deletions tests/external/mrvi/test_model.py
@@ -117,6 +117,7 @@ def test_mrvi_de(model: MRVI, setup_kwargs: dict[str, Any], de_kwargs: dict[str,
     [
         {"sample_cov_keys": ["meta1_cat"]},
         {"sample_cov_keys": ["meta1_cat", "batch"]},
+        {"sample_cov_keys": ["meta1_cat"], "omit_original_sample": False},
         {"sample_cov_keys": ["meta1_cat"], "compute_log_enrichment": True},
         {"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True},
     ],
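For reference, a hypothetical call exercising the new flag, assuming a trained MRVI `model` and an `adata` set up as in these tests (`meta1_cat` is the covariate key from the parametrization above):

```python
# New default: each cell's own sample is masked out of the aggregation.
da_res = model.differential_abundance(adata, sample_cov_keys=["meta1_cat"])

# Opt back into the previous behavior, keeping the sample of origin.
da_res_all = model.differential_abundance(
    adata,
    sample_cov_keys=["meta1_cat"],
    omit_original_sample=False,
)
```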
