Fix MixtureSameFamily log probability computations as described in #3188 (#3189)

Open · wants to merge 17 commits into base: main
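
For context: the fix hinges on how numpyro's MixtureSameFamily interprets the batch dimensions of its component distribution. The sketch below is illustrative only (it is not code from the PR) and uses nothing beyond numpyro and JAX, which the diff already relies on:

    # Illustrative shape check, not part of the PR.
    import jax.numpy as jnp
    from numpyro.distributions import Categorical, MixtureSameFamily, Normal

    n_comp, n_latent, n_cells = 3, 4, 5
    u = jnp.zeros((n_cells, n_latent))
    mix = Categorical(probs=jnp.ones(n_comp) / n_comp)

    # Fixed layout: (K, D) parameters plus .to_event(1) make each component a
    # D-dimensional diagonal Normal, so the mixture's log_prob has shape (n_cells,).
    locs, scales = jnp.zeros((n_comp, n_latent)), jnp.ones((n_comp, n_latent))
    mixture = MixtureSameFamily(mix, Normal(locs, scales).to_event(1))
    print(mixture.log_prob(u).shape)  # (5,)

    # Old layout: (D, K) parameters without an event dimension mix each latent
    # dimension independently; log_prob has shape (n_cells, n_latent), and summing
    # it afterwards yields a product of 1-D mixtures, not a mixture of D-dim Normals.
    per_dim = MixtureSameFamily(mix, Normal(locs.T, scales.T))
    print(per_dim.log_prob(u).shape)  # (5, 4)

The diffs below apply this change in three places: the aggregated posterior in _model.py, the mixture prior in _module.py, and the KL term of the loss.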
65 changes: 48 additions & 17 deletions src/scvi/external/mrvi/_model.py
@@ -783,7 +783,11 @@ def get_aggregated_posterior(
-------
A mixture distribution of the aggregated posterior.
"""
-from numpyro.distributions import Categorical, MixtureSameFamily, Normal
+from numpyro.distributions import (
+    Categorical,
+    MixtureSameFamily,
+    Normal,
+)

self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)
@@ -806,11 +810,11 @@ def get_aggregated_posterior(
qu_locs.append(outputs["qu"].loc)
qu_scales.append(outputs["qu"].scale)

-qu_loc = jnp.concatenate(qu_locs, axis=0).T
-qu_scale = jnp.concatenate(qu_scales, axis=0).T
+qu_loc = jnp.concatenate(qu_locs, axis=0)  # n_cells x n_latent_u
+qu_scale = jnp.concatenate(qu_scales, axis=0)  # n_cells x n_latent_u
return MixtureSameFamily(
-    Categorical(probs=jnp.ones(qu_loc.shape[1]) / qu_loc.shape[1]),
-    Normal(qu_loc, qu_scale),
+    Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]),
+    (Normal(qu_loc, qu_scale).to_event(1)),
)

def differential_abundance(
@@ -819,6 +823,7 @@
sample_cov_keys: list[str] | None = None,
sample_subset: list[str] | None = None,
compute_log_enrichment: bool = False,
+omit_original_sample: bool = True,
batch_size: int = 128,
) -> xr.Dataset:
"""Compute the differential abundance between samples.
@@ -839,6 +844,9 @@
Only compute differential abundance for these sample labels.
compute_log_enrichment
Whether to compute the log enrichment scores for each covariate value.
+omit_original_sample
+    If True, the log-probability under a cell's own sample-of-origin is excluded when
+    aggregating across samples. Only relevant if sample_cov_keys is not None.
batch_size
Minibatch size for computing the differential abundance.

@@ -883,7 +891,7 @@
n_splits = max(adata.n_obs // batch_size, 1)
log_probs_ = []
for u_rep in np.array_split(us, n_splits):
-    log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True)))
+    log_probs_.append(jax.device_get(ap.log_prob(u_rep))[..., np.newaxis])

log_probs.append(np.concatenate(log_probs_, axis=0)) # (n_cells, 1)

@@ -901,6 +909,23 @@
if sample_cov_keys is None or len(sample_cov_keys) == 0:
return log_probs_arr

+def aggregate_log_probs(log_probs, samples, omit_original_sample=False):
+    sample_log_probs = log_probs.loc[
+        {"sample": samples}
+    ].values  # (n_cells, n_samples_in_group)
+    if omit_original_sample:
+        sample_one_hot = np.zeros((adata.n_obs, len(samples)))
+        for i, sample in enumerate(samples):
+            sample_one_hot[adata.obs[self.sample_key] == sample, i] = 1
+        log_probs_no_original = np.where(
+            sample_one_hot, -np.inf, sample_log_probs
+        )  # virtually discards samples-of-origin from aggregate posteriors
+        return logsumexp(log_probs_no_original, axis=1) - np.log(
+            (1 - sample_one_hot).sum(axis=1)
+        )
+    else:
+        return logsumexp(sample_log_probs, axis=1) - np.log(sample_log_probs.shape[1])

sample_cov_log_probs_map = {}
sample_cov_log_enrichs_map = {}
for sample_cov_key in sample_cov_keys:
@@ -916,8 +941,11 @@
if len(cov_samples) == 0:
continue

-sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}]
-val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(sel_log_probs.shape[1])
+val_log_probs = aggregate_log_probs(
+    log_probs_arr.log_probs,
+    cov_samples,
+    omit_original_sample=omit_original_sample,
+)
per_val_log_probs[sample_cov_value] = val_log_probs

if compute_log_enrichment:
@@ -930,9 +958,10 @@
stacklevel=2,
)
continue
-rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}]
-rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log(
-    rest_log_probs.shape[1]
+rest_val_log_probs = aggregate_log_probs(
+    log_probs_arr.log_probs,
+    rest_samples,
+    omit_original_sample=omit_original_sample,
)
enrichment_scores = val_log_probs - rest_val_log_probs
per_val_log_enrichs[sample_cov_value] = enrichment_scores
@@ -1018,8 +1047,8 @@ def get_outlier_cell_sample_pairs(

ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs)
in_max_comp_log_probs = ap.component_distribution.log_prob(
-    np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)
-).sum(axis=1)
+    np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)  # (n_cells_ap, 1, n_latent_dim)
+)  # (n_cells_ap, n_cells_ap)
log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs)

log_probs_ = []
@@ -1028,10 +1057,12 @@
log_probs_.append(
jax.device_get(
ap.component_distribution.log_prob(
-    np.expand_dims(u_rep, ap.mixture_dim)
-)  # (n_cells_batch, n_cells_ap, n_latent_dim)
-.sum(axis=1)  # (n_cells_batch, n_latent_dim)
-.max(axis=1, keepdims=True)  # (n_cells_batch, 1)
+    np.expand_dims(
+        u_rep, ap.mixture_dim
+    )  # (n_cells_batch, 1, n_latent_dim)
+).max(  # (n_cells_batch, n_cells_ap)
+    axis=1, keepdims=True
+)  # (n_cells_batch, 1)
)
)

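An aside on the aggregate_log_probs helper added above (illustrative, not from the PR): with omit_original_sample=True it masks each cell's own sample with -inf before the logsumexp, then averages over the remaining samples in the group. A self-contained numeric sketch of just that masking step, assuming SciPy's logsumexp (the helper itself only calls a logsumexp already in scope):

    # Standalone illustration of the masking trick; the array values are made up.
    import numpy as np
    from scipy.special import logsumexp

    # Per-cell log-probabilities under each sample's aggregated posterior.
    sample_log_probs = np.array([[-1.0, -2.0, -3.0],
                                 [-0.5, -1.5, -2.5]])  # (n_cells, n_samples_in_group)
    # 1 marks each cell's sample-of-origin.
    sample_one_hot = np.array([[1, 0, 0],
                               [0, 0, 1]])

    masked = np.where(sample_one_hot, -np.inf, sample_log_probs)
    # Mean over the remaining samples (in probability space), own sample excluded.
    avg_log_probs = logsumexp(masked, axis=1) - np.log((1 - sample_one_hot).sum(axis=1))
    print(avg_log_probs)  # one value per cell
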
21 changes: 12 additions & 9 deletions src/scvi/external/mrvi/_module.py
@@ -268,7 +268,10 @@ def __call__(
sample_covariate: jax.typing.ArrayLike,
training: bool | None = None,
) -> dist.Normal:
-from scvi.external.mrvi._components import ConditionalNormalization, NormalDistOutputNN
+from scvi.external.mrvi._components import (
+    ConditionalNormalization,
+    NormalDistOutputNN,
+)

training = nn.merge_param("training", self.training, training)
x_feat = jnp.log1p(x)
@@ -424,10 +427,10 @@ def setup(self):
"u_prior_logits", nn.initializers.zeros, (u_prior_mixture_k,)
)
self.u_prior_means = self.param(
"u_prior_means", jax.random.normal, (u_dim, u_prior_mixture_k)
"u_prior_means", jax.random.normal, (u_prior_mixture_k, u_dim)
)
self.u_prior_scales = self.param(
"u_prior_scales", nn.initializers.zeros, (u_dim, u_prior_mixture_k)
"u_prior_scales", nn.initializers.zeros, (u_prior_mixture_k, u_dim)
)

@property
@@ -517,7 +520,9 @@ def generative(
10.0 * jax.nn.one_hot(label_index, self.n_labels) if self.n_labels >= 2 else 0.0
)
cats = dist.Categorical(logits=self.u_prior_logits + offset)
-normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales))
+normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales)).to_event(
+    1
+)

Review comment (Contributor) on this line, with a snippet to double-check it works in loss (it works):

        pu = generative_outputs["pu"]

        # approach 1: below is the reference implementation of the log_prob of the mixture distribution
        log_pu = generative_outputs["pu"].log_prob(inference_outputs["u"])

        # approach 2: manual computation of the log_prob of the mixture distribution
        mixing_distribution = pu._mixing_distribution
        component_distribution = pu._component_distribution
        pk = mixing_distribution.probs
        mus = component_distribution.base_dist.loc
        stds = component_distribution.base_dist.scale

        pu_k = dist.Normal(mus, stds)
        log_pu_k = pu_k.log_prob(inference_outputs["u"][:, None, :]).sum(-1)
        log_pk = jnp.log(pk)
        log_puk = log_pk + log_pu_k
        # above are joint probas
        # shape (n_cells, n_mixtures)

        log_prob_u = jax.scipy.special.logsumexp(log_puk, axis=1)  # shape (n_cells,)
        # (log_prob_u == log_pu) ?
pu = dist.MixtureSameFamily(cats, normal_dists)
else:
pu = dist.Normal(0, jnp.exp(self.u_prior_scale))
@@ -536,13 +541,11 @@ def loss(
)

if self.u_prior_mixture:
-kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]) - generative_outputs[
-    "pu"
-].log_prob(inference_outputs["u"])
-kl_u = kl_u.sum(-1)
+kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]).sum(
+    -1
+) - generative_outputs["pu"].log_prob(inference_outputs["u"])
else:
kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]).sum(-1)
inference_outputs["qeps"]

kl_z = 0.0
eps = inference_outputs["z"] - inference_outputs["z_base"]
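One more illustrative note on the loss change above (not from the PR discussion): qu.log_prob(u) is still per dimension, shape (n_cells, n_latent), while the reshaped mixture prior's log_prob already sums over its event dimension and returns (n_cells,). That is why the .sum(-1) moved onto the first term of kl_u. A minimal shape check under those assumptions:

    import jax.numpy as jnp
    import numpyro.distributions as dist

    n_cells, n_latent, n_mix = 5, 4, 3
    u = jnp.zeros((n_cells, n_latent))

    # qu: per-cell, per-dimension Normal -> log_prob has shape (n_cells, n_latent).
    qu = dist.Normal(jnp.zeros((n_cells, n_latent)), jnp.ones((n_cells, n_latent)))

    # pu: mixture of n_mix full n_latent-dimensional Normals -> log_prob has shape (n_cells,).
    pu = dist.MixtureSameFamily(
        dist.Categorical(logits=jnp.zeros(n_mix)),
        dist.Normal(jnp.zeros((n_mix, n_latent)), jnp.ones((n_mix, n_latent))).to_event(1),
    )

    # Monte Carlo KL term as in the updated loss: sum qu over latent dims first.
    kl_u = qu.log_prob(u).sum(-1) - pu.log_prob(u)
    print(kl_u.shape)  # (5,)
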
1 change: 1 addition & 0 deletions tests/external/mrvi/test_model.py
@@ -117,6 +117,7 @@ def test_mrvi_de(model: MRVI, setup_kwargs: dict[str, Any], de_kwargs: dict[str,
[
{"sample_cov_keys": ["meta1_cat"]},
{"sample_cov_keys": ["meta1_cat", "batch"]},
{"sample_cov_keys": ["meta1_cat"], "omit_original_sample": False},
{"sample_cov_keys": ["meta1_cat"], "compute_log_enrichment": True},
{"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True},
],