Fix MixtureSameFamily log probability computations as described in #3188 (#3189)
base: main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #3189      +/-   ##
==========================================
- Coverage   89.27%   86.11%    -3.16%
==========================================
  Files         185      185
  Lines       16265    16270        +5
==========================================
- Hits        14520    14011      -509
- Misses       1745     2259      +514
Force-pushed from c40c156 to 1660bbb (compare).
LGTM
@@ -517,7 +520,9 @@ def generative(
             10.0 * jax.nn.one_hot(label_index, self.n_labels) if self.n_labels >= 2 else 0.0
         )
         cats = dist.Categorical(logits=self.u_prior_logits + offset)
-        normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales))
+        normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales)).to_event(
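As a side note, here is a minimal standalone sketch (made-up shapes and names, not the scvi-tools code) of what the added .to_event(1) changes: each component Normal becomes a diagonal Normal over the whole latent vector, so the mixture's log_prob sums over the latent dimensions and returns one value per cell.

import jax.numpy as jnp
import numpyro.distributions as dist

# hypothetical shapes, for illustration only
n_mixtures, n_latent, n_cells = 3, 5, 7
logits = jnp.zeros(n_mixtures)
means = jnp.linspace(-1.0, 1.0, n_mixtures * n_latent).reshape(n_mixtures, n_latent)
scales = jnp.ones((n_mixtures, n_latent))

cats = dist.Categorical(logits=logits)
# .to_event(1) reinterprets the latent axis as an event dimension:
# batch_shape (n_mixtures,), event_shape (n_latent,)
normal_dists = dist.Normal(means, scales).to_event(1)
pu = dist.MixtureSameFamily(cats, normal_dists)

u = jnp.zeros((n_cells, n_latent))
print(pu.event_shape)        # (n_latent,)
print(pu.log_prob(u).shape)  # (n_cells,) -- one log density per cell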
Snippet to double-check that this works in the loss (it works):
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

pu = generative_outputs["pu"]
# approach 1: reference implementation of the log_prob of the mixture distribution
log_pu = pu.log_prob(inference_outputs["u"])
# approach 2: manual computation of the log_prob of the mixture distribution
mixing_distribution = pu.mixing_distribution
component_distribution = pu.component_distribution
pk = mixing_distribution.probs
mus = component_distribution.base_dist.loc
stds = component_distribution.base_dist.scale
pu_k = dist.Normal(mus, stds)
log_pu_k = pu_k.log_prob(inference_outputs["u"][:, None, :]).sum(-1)
log_pk = jnp.log(pk)
# joint log probabilities, shape (n_cells, n_mixtures)
log_puk = log_pk + log_pu_k
log_prob_u = jax.scipy.special.logsumexp(log_puk, axis=1)  # shape (n_cells,)
# check: log_prob_u == log_pu (up to numerical tolerance)
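And a self-contained variant of the same check with synthetic parameters (shapes here are arbitrary and not tied to the model's outputs), confirming that the manual logsumexp matches MixtureSameFamily.log_prob:

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
from jax.scipy.special import logsumexp

n_cells, n_mixtures, n_latent = 4, 3, 5
key_mu, key_u = jax.random.split(jax.random.PRNGKey(0))
logits = jnp.array([0.5, -0.5, 1.0])
mus = jax.random.normal(key_mu, (n_mixtures, n_latent))
stds = jnp.ones((n_mixtures, n_latent))
u = jax.random.normal(key_u, (n_cells, n_latent))

pu = dist.MixtureSameFamily(
    dist.Categorical(logits=logits),
    dist.Normal(mus, stds).to_event(1),
)
log_pu = pu.log_prob(u)  # reference

# manual: per-component joint log prob, then logsumexp over components
log_pu_k = dist.Normal(mus, stds).log_prob(u[:, None, :]).sum(-1)  # (n_cells, n_mixtures)
log_pk = jax.nn.log_softmax(logits)
log_prob_u = logsumexp(log_pk + log_pu_k, axis=1)  # (n_cells,)

assert jnp.allclose(log_pu, log_prob_u, atol=1e-5)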
for more information, see https://pre-commit.ci
Looks good to me, see comments!
Fixes #3188.
This PR fixes the issue in both the get_aggregated_posterior function and the Mixture of Gaussians prior. Tested on the tutorial: integration remains the same, and the differential abundance results now span a much larger range, as expected (see issue). To mitigate spikiness in favor of the sample of origin's covariate, we added the default omit_original_sample=True, which avoids incorporating the sample of origin's log prob into the calculation. A rough sketch of that idea follows below.
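For readers of the diff, here is a rough sketch of the omit_original_sample idea; the function name, signature, and shapes below are illustrative assumptions, not the actual scvi-tools implementation. The point is to mask out the cell's own sample before aggregating per-sample log densities.

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def aggregate_log_probs(log_probs, sample_index, omit_original_sample=True):
    # Hypothetical helper, not the scvi-tools API.
    # log_probs: (n_cells, n_samples) log density of each cell under each
    # sample's aggregated posterior; sample_index: (n_cells,) sample of origin.
    n_cells, n_samples = log_probs.shape
    if omit_original_sample:
        own = jnp.eye(n_samples, dtype=bool)[sample_index]   # (n_cells, n_samples)
        log_probs = jnp.where(own, -jnp.inf, log_probs)      # drop the cell's own sample
        n_kept = n_samples - 1
    else:
        n_kept = n_samples
    # average the densities over the remaining samples, in log space
    return logsumexp(log_probs, axis=1) - jnp.log(n_kept)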
CC @rastogiruchir