Fix MixtureSameFamily log probability computations as described in #3188 (#3189)

Open · wants to merge 17 commits into main

Conversation

@justjhong (Contributor) commented Feb 13, 2025

Fixes #3188.

This PR fixes the issue in both the get_aggregated_posterior function and the mixture-of-Gaussians prior.

Tested on the tutorial: integration results are unchanged, and the differential abundance results now span a much larger range, as expected (see the issue). To mitigate spikiness that favors the covariate of the sample of origin, we added omit_original_sample=True as the default, so that the log probability under the sample of origin is excluded from the calculation.
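For context, here is a minimal, self-contained sketch (not the scvi-tools implementation; the array names and shapes are hypothetical) of how an aggregated posterior can be built with numpyro distributions so that its log probability is taken over the full latent vector rather than per dimension:

    import jax.numpy as jnp
    import numpyro.distributions as dist

    # Hypothetical per-cell variational posterior parameters for one sample,
    # with shapes (n_cells, n_latent).
    n_cells, n_latent = 128, 10
    locs = jnp.zeros((n_cells, n_latent))
    scales = jnp.ones((n_cells, n_latent))

    # Uniform mixture over the sample's cells. `.to_event(1)` reinterprets the
    # latent dimension as the event shape, so each component contributes a
    # joint log probability (summed over latent dimensions) rather than a
    # per-dimension one.
    aggregated_posterior = dist.MixtureSameFamily(
        dist.Categorical(probs=jnp.full(n_cells, 1.0 / n_cells)),
        dist.Normal(locs, scales).to_event(1),
    )

    # One log probability per query point: shape (n_query,).
    u_query = jnp.zeros((4, n_latent))
    log_probs = aggregated_posterior.log_prob(u_query)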

CC @rastogiruchir

@justjhong added the `optional tests` (Run optional tests) and `cuda tests` (Run test suite on CUDA) labels on Feb 13, 2025

codecov bot commented Feb 13, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.11%. Comparing base (c3926eb) to head (798eee6).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3189      +/-   ##
==========================================
- Coverage   89.27%   86.11%   -3.16%     
==========================================
  Files         185      185              
  Lines       16265    16270       +5     
==========================================
- Hits        14520    14011     -509     
- Misses       1745     2259     +514     
Files with missing lines              Coverage Δ
src/scvi/external/mrvi/_model.py      89.17% <100.00%> (+0.14%) ⬆️
src/scvi/external/mrvi/_module.py     96.01% <100.00%> (-0.04%) ⬇️

... and 19 files with indirect coverage changes

@justjhong changed the title from "Fix aggregated posterior computation as described in #3188" to "Fix MixtureSameFamily log probability computations as described in #3188" on Feb 13, 2025
@justjhong removed the `cuda tests` (Run test suite on CUDA) label on Feb 13, 2025
@ori-kron-wis added the `on-merge: backport to 1.2.x` label on Feb 17, 2025
@PierreBoyeau (Contributor) left a comment:

LGTM

@@ -517,7 +520,9 @@ def generative(
             10.0 * jax.nn.one_hot(label_index, self.n_labels) if self.n_labels >= 2 else 0.0
         )
         cats = dist.Categorical(logits=self.u_prior_logits + offset)
-        normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales))
+        normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales)).to_event(
Inline review comment: snippet to double-check that the mixture log probability works in the loss (it does):

        # Context: inside the module's loss; `generative_outputs`, `inference_outputs`,
        # `jax`, `jnp`, and `dist` (numpyro.distributions) come from the surrounding code.
        pu = generative_outputs["pu"]

        # Approach 1: log_prob of the MixtureSameFamily distribution (the reference).
        log_pu = pu.log_prob(inference_outputs["u"])

        # Approach 2: manual computation of the mixture log probability.
        mixing_distribution = pu._mixing_distribution
        component_distribution = pu._component_distribution
        pk = mixing_distribution.probs
        mus = component_distribution.base_dist.loc    # base_dist is the Normal under .to_event
        stds = component_distribution.base_dist.scale

        pu_k = dist.Normal(mus, stds)
        # Sum over latent dimensions -> per-component joint log probabilities.
        log_pu_k = pu_k.log_prob(inference_outputs["u"][:, None, :]).sum(-1)
        log_pk = jnp.log(pk)
        # Joint log probabilities log p(u, k), shape (n_cells, n_mixtures).
        log_puk = log_pk + log_pu_k

        # Marginalize over components, shape (n_cells,).
        log_prob_u = jax.scipy.special.logsumexp(log_puk, axis=1)
        # Check: log_prob_u should match log_pu.

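As a complementary, self-contained illustration (hypothetical shapes, numpyro distributions; not code from this PR), the following contrasts the joint mixture log probability obtained when the latent dimensions are folded into the event shape via `.to_event(1)` against the per-dimension mixture that arises when they are not; the two quantities generally differ:

    import jax
    import jax.numpy as jnp
    import numpyro.distributions as dist

    n_components, n_latent = 3, 5
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

    # Hypothetical mixture-of-Gaussians parameters.
    probs = jnp.full(n_components, 1.0 / n_components)
    means = jax.random.normal(k1, (n_components, n_latent))
    scales = jnp.exp(0.1 * jax.random.normal(k2, (n_components, n_latent)))
    u = jax.random.normal(k3, (n_latent,))

    # Joint mixture: log sum_k pi_k prod_d N(u_d | mu_kd, s_kd).
    joint = dist.MixtureSameFamily(
        dist.Categorical(probs=probs),
        dist.Normal(means, scales).to_event(1),
    ).log_prob(u)

    # Per-dimension mixture: sum_d log sum_k pi_k N(u_d | mu_kd, s_kd).
    per_dim = jax.scipy.special.logsumexp(
        jnp.log(probs)[:, None] + dist.Normal(means, scales).log_prob(u), axis=0
    ).sum()

    print(joint, per_dim)  # generally not equal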
@PierreBoyeau (Contributor) left a comment:

Looks good to me, see comments!

@PierreBoyeau (Contributor) commented:

Quick observations:

  • These changes do not seem to affect the analyses conducted in the paper. Here are the log ratios (COVID vs. healthy) from the complete Haniffa dataset, showing behavior similar to the MrVI manuscript.
    [figure: log ratios (COVID vs. healthy) on the full Haniffa dataset]
  • That being said, the log ratios do look slightly different when computed on the drastically subsampled Haniffa dataset. If similar issues are reported on small datasets, we may want to explore options for smoothing the log ratios.

Labels
on-merge: backport to 1.2.x · optional tests (Run optional tests)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Potential bug in MRVI's get_aggregated_posterior
3 participants