Handle zero-weight Gaussians correctly in VariantRecalibrator #6425
Conversation
@davidbenjamin want to increase the VQSR bus number?
@ldgauthier It looks good and I see how it resolves an edge case, but I have a couple of questions before the bus number can really be said to have inched up.
-            pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + MathUtils.normalDistributionLog10(gaussian.mu[iii], gaussian.sigma.get(iii, iii), datum.annotations[iii]);
+            pVarInGaussianLog10[gaussianIndex] = gaussian.pMixtureLog10;
+            if (gaussian.pMixtureLog10 != Double.NEGATIVE_INFINITY) {
+                pVarInGaussianLog10[gaussianIndex] += MathUtils.normalDistributionLog10(gaussian.mu[iii], gaussian.sigma.get(iii, iii), datum.annotations[iii]);
It's not clear to me why this special case is needed. If it's negative infinity, then adding the normalDistributionLog10 to it yields negative infinity. You avoid calculating it, of course, but presumably this is an edge case and the optimization doesn't matter.
My other guess is that sigma or mu could be NaNs somehow? If so, why is this connected to zero weight?
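For what it's worth, a quick illustration of why the guard matters if mu/sigma really can be NaN (plain Python as a sketch; Java doubles follow the same IEEE-754 rules):

```python
neg_inf = float("-inf")

# A finite log-density added to -inf stays -inf, so a zero-weight component
# stays harmless without the guard ...
print(neg_inf + (-3.7))         # -inf

# ... but a NaN "log-density" (e.g. from NaN mu/sigma) poisons the sum:
print(neg_inf + float("nan"))   # nan

# Skipping the density evaluation when pMixtureLog10 is -infinity therefore
# keeps NaNs out of the per-variant log-likelihoods.
```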
Yeah, sigma and mu are in fact NaNs. I'm not entirely sure where that happens in the code, but I can look.
I think this typically happens when the effective number (i.e., the sum of responsibilities) for the component goes to zero; if so, you can just add an epsilon to the denominators in the updates for the means and covariances to avoid this. Perhaps see my notebook from an ancient MIA primer for some pointers.
Any plans to update this implementation? See a few of my objections in #2062.
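A minimal numpy sketch of the epsilon suggestion above (illustrative names, not the GATK code): add a small epsilon to the effective counts before dividing in the mean/covariance updates, so an empty component can't produce NaNs.

```python
import numpy as np

def m_step_with_epsilon(X, resp, eps=1e-10):
    """Responsibility-weighted mean/covariance updates with an epsilon guard.

    X:    (N, D) data matrix
    resp: (N, K) responsibilities from the E step
    Illustrative only -- not the GATK implementation.
    """
    n_k = resp.sum(axis=0) + eps                 # effective counts, kept strictly positive
    mu = (resp.T @ X) / n_k[:, None]             # (K, D) component means
    n_components, d = resp.shape[1], X.shape[1]
    sigma = np.empty((n_components, d, d))
    for k in range(n_components):
        diff = X - mu[k]                         # (N, D) deviations from the k-th mean
        sigma[k] = (resp[:, k, None] * diff).T @ diff / n_k[k]
    return mu, sigma
```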
I think I understand what's going on (that bus number is on the move now). When a Gaussian has no data assigned, the M step MultivariateGaussian.maximizeGaussian ends up with zero matrices for pVarSigma and wishart, and thus the resulting sigma equals the empiricalSigma, which is just the whole-data (unclustered) covariance matrix.
Since one of the annotations is constant and has zero variance, this empiricalSigma is degenerate. Since the clusters with data assigned have non-zero wishart and pVarSigma, only the empty cluster has a problem.
I'm pretty sure that this is wrong, because variational Bayes should regularize the tendency of an empty cluster to take on a degenerate value -- that's the role of the prior. We could fix it by reconciling the code with Chapter 10 of Bishop (at the cost of several days of dragon-slaying), or we could just re-randomly initialize sigma for an empty cluster, but I would like to fix the problem at its source so that we avoid creating a pathological Gaussian in the first place.
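A small numpy illustration of the degeneracy being described (hypothetical data, not the actual test inputs): a constant annotation column makes the whole-data covariance singular, so an empty cluster that falls back to it inherits a degenerate sigma.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
annotations = np.column_stack([
    rng.normal(size=n),      # an annotation with real spread
    np.full(n, 3.0),         # a constant annotation (zero variance)
])

empirical_sigma = np.cov(annotations, rowvar=False)
print(empirical_sigma)                    # second row and column are exactly zero
print(np.linalg.det(empirical_sigma))     # 0.0 -> the fallback Gaussian is degenerate
```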
I understand the basic idea of the Bishop approach of dealing with the singularity (mentioned at the end of 10.2.1), but I don't have any confidence in being able to work out the new closed-form solution. Maybe I can get some help offline?
I have a number of reasons to think it's just standard Bishop. To fully generalize, I think all that has to happen is to uncouple the actual data from the rest of the VariantDatum object, which I can probably do in a not unreasonable amount of time. I think I modified that class in the past. I might also be able to fully uncouple the model training from the inference, which I did in a crufty way in the past. And improve on the memory requirements...
I'm not yet convinced that our implementation is faithful to the variational Bayes GMM. Bishop 10.6 factorizes down to a joint distribution on mu and sigma, but it doesn't split any further, for one thing. And even if it did, our GMM has a point estimate of sigma for each cluster, which is not the same as a VB M step that computes a posterior. Finally, I couldn't find anything in Bishop that resembles our empiricalSigma, which is the covariance of all the data.
Note that one might worry that Bishop's joint posterior on mu and sigma -- a Gauss-Wishart distribution -- would be intractable for VQSR because the predictive densities could not be computed, but this is not the case. As Bishop shows, this is just a Student's t.
This is not to say that I think we should go after this dragon, but I think it's worth settling whether our implementation is actually a true VB GMM.
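For reference, going from memory of Bishop (PRML, around eq. 10.81), the predictive density under the Gauss-Wishart posterior is the mixture of Student's t distributions alluded to above, with \(\widehat{\alpha} = \sum_k \alpha_k\):

```latex
p(\hat{\mathbf{x}} \mid \mathbf{X}) \simeq \frac{1}{\widehat{\alpha}}
  \sum_{k=1}^{K} \alpha_k\,
  \mathrm{St}\!\left(\hat{\mathbf{x}} \,\middle|\, \mathbf{m}_k,\ \mathbf{L}_k,\ \nu_k + 1 - D\right),
\qquad
\mathbf{L}_k = \frac{(\nu_k + 1 - D)\,\beta_k}{1 + \beta_k}\,\mathbf{W}_k
```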
Don't empiricalMu and empiricalSigma determine the respective priors? From a glance, it looks like they are initialized to the origin and some magic number times the identity matrix, respectively, and then never updated. You might be right about the implementation not being exactly faithful, but I think I concluded that it probably attempted to be.
I think that the difficulty of maintaining or improving this code arises because 1) the intended implementation is not clearly documented and 2) the code is not documented or easily parsed as the corresponding mathematical expressions. This is further compounded by the need for the relatively trivial extraction and refactoring mentioned previously.
The point estimate you mention (i.e., the one computed by evaluateFinalModelParameters) seems to be somewhat arbitrarily chosen, but I think that the updates for the variational parameters calculated in maximizeGaussian do seem to follow 10.58 and 10.60-63.
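For reference, my recollection of those updates (Bishop, PRML, ch. 10), where \(N_k\), \(\bar{\mathbf{x}}_k\), and \(\mathbf{S}_k\) are the responsibility-weighted count, mean, and scatter of component k:

```latex
\begin{align*}
\alpha_k &= \alpha_0 + N_k && \text{(10.58)}\\
\beta_k  &= \beta_0 + N_k && \text{(10.60)}\\
\mathbf{m}_k &= \tfrac{1}{\beta_k}\left(\beta_0 \mathbf{m}_0 + N_k \bar{\mathbf{x}}_k\right) && \text{(10.61)}\\
\mathbf{W}_k^{-1} &= \mathbf{W}_0^{-1} + N_k \mathbf{S}_k
  + \frac{\beta_0 N_k}{\beta_0 + N_k}
    \left(\bar{\mathbf{x}}_k - \mathbf{m}_0\right)\left(\bar{\mathbf{x}}_k - \mathbf{m}_0\right)^{\mathsf{T}} && \text{(10.62)}\\
\nu_k &= \nu_0 + N_k && \text{(10.63)}
\end{align*}
```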
Just noting for future reference that I went digging into the original GATK3 code (the commit history over there seems corrupted... not sure if something went wrong at some point, perhaps during a public/protected split?). It looks like empiricalMu and empiricalSigma were indeed originally initialized to the mean and covariance of the data (although the code for the latter was commented out); I'm not sure what the justification for this was. In broadgsa/gatk@3224bbe these were changed to represent the quantities for the Bishop prior, but the variable names were not changed.
@@ -406,6 +406,8 @@ public void testVariantRecalibratorSNPscattered(final String[] params) throws IO
        doSNPTest(params, getLargeVQSRTestDataDir() + "/snpTranches.scattered.txt", getLargeVQSRTestDataDir() + "snpRecal.vcf"); //tranches file isn't in the expected/ directory because it's input to GatherTranchesIntegrationTest
    }

    //One of the 8 positive Gaussians (index 4) has weight zero, which is fine, but the one at index 2 has a covariance
Is there anything in the data or arguments from which we expect a priori that these errors will occur, and at the given indices?
In the input data AS_MQrankSum is constant, which I note where the argument string is defined, but I could make that more clear. Specifying the index of the Gaussian probably isn't informative since anything that touches the random number generator has the potential to change that.
In that case a comment here that the constantness (constancy?) of AS_MQrankSum causes one Gaussian to have a degenerate covariance would be helpful. I would either not even mention that it's the one with index 2 or be explicit that this is arbitrary (likewise for the zero-weight Gaussian at index 4).
one or more Gaussians to zero (hence *MAX* Gaussians), but this causes problems in log10 space
Force-pushed from 67fd85d to f54aafd
@davidbenjamin removing the final M-ish step in "evaluateFinalModelParameters" like we talked about made a huge difference in the VQSLODs in the test data. I'm going to continue looking into it, but I really want to get this fix for production into this week's release.
@ldgauthier I can live with that. 👍 to the quick fix.
@davidbenjamin can I get a review +1 so I can get this into the release?
Yes.
@ldgauthier @davidbenjamin any decision about what further improvements are worth pursuing? Looking for some quick and easy things to look into now that I'm back. In any case, are there any baseline runtime/memory estimates? How about typical numbers of variants and annotation dimensions that the tool is run with? Just curious about a quick comparison with sklearn, etc.
It's on my list. Pretty near the bottom, but it's there. Runtime is probably getting up near an hour for big jobs. The memory requirements are horrific, because we load all the variants into memory and then we don't even use them all! For the biggest cohorts we use 104GB. I wish I were joking. If sklearn can minibatch GMMs then that would be amazing. We use a maximum of 2.5M variants for training, and the number of annotations/dimensions is O(10). The smallest exome cohort would train with about 80,000 variants with about 3GB of memory. That being said, this definitely isn't the biggest cost contributor for joint calling, and hopefully all the sporadic failures have been hammered out.
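A rough back-of-the-envelope (my own arithmetic, not measured numbers) suggests the raw annotation matrix is a tiny fraction of that 104 GB, so most of the footprint presumably comes from per-variant objects and other bookkeeping:

```python
variants = 2_500_000
annotations = 10                 # "O(10)" dimensions
bytes_per_double = 8
matrix_gb = variants * annotations * bytes_per_double / 1e9
print(f"dense double matrix: ~{matrix_gb:.1f} GB")   # ~0.2 GB
```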
Did a quick test with sklearn's BayesianGaussianMixture, fitting 8 components to 2.5M 10D points generated from 4 isotropic blobs. On a Google Colab instance (which I believe are n1-highmem-2s), 150 iterations (which I think is the current maximum) completed in 14 minutes. Note that convergence within the default tolerance isn't actually reached in 150 iterations for this toy data (as usual, it takes a while for the weights of unused components to shrink to zero). In any case, we'd have to compare against the number of iterations currently required to converge with the real data (and perhaps also check that the convergence criteria match up) to get a better idea of real runtime. Various tweaks to priors or other runtime options (such as k-means vs. random initialization) could also affect convergence speed. Minibatching isn't built in, but I think it should be pretty trivial to hack together something with the warm_start option.
Doing a warm start with 1% of the toy data followed by a full fit takes <30s from start to finish and convergence is reached on the full fit in less than 10 iterations. We'd probably have to do some testing on real data to make sure we don't lose any important clusters with low weight using such a strategy (or some variation thereof), but in any case I think we can come in well under 1hr and 104GB...
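For concreteness, a sketch of the warm-start strategy described above (the priors, options, and stand-in data are placeholders, not a tested VQSR configuration):

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

rng = np.random.default_rng(0)
X = rng.normal(size=(2_500_000, 10))     # stand-in for the annotation matrix

gmm = BayesianGaussianMixture(
    n_components=8,
    covariance_type="full",
    max_iter=150,
    warm_start=True,                     # subsequent fit() calls continue from the previous solution
    random_state=0,
)

# Cheap fit on ~1% of the data, then a warm-started full fit.
subsample = X[rng.choice(len(X), size=len(X) // 100, replace=False)]
gmm.fit(subsample)
gmm.fit(X)
```

With warm_start=True, the second fit reuses the weights, means, and precisions from the subsample fit as its starting point, which is what lets the full fit converge in only a few iterations.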
That all sounds really good.