
Handle zero-weight Gaussians correctly in VariantRecalibrator #6425

Merged: 1 commit merged into master from ldg_debugPositiveModel on Feb 7, 2020

Conversation

ldgauthier (Contributor):

No description provided.

@ldgauthier (Contributor, Author):

@davidbenjamin want to increase the VQSR bus number?

@davidbenjamin (Contributor) left a comment:

@ldgauthier It looks good and I see how it resolves an edge case, but I have a couple of questions before the bus number can really be said to have inched up.

- pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + MathUtils.normalDistributionLog10(gaussian.mu[iii], gaussian.sigma.get(iii, iii), datum.annotations[iii]);
+ pVarInGaussianLog10[gaussianIndex] = gaussian.pMixtureLog10;
+ if (gaussian.pMixtureLog10 != Double.NEGATIVE_INFINITY) {
+     pVarInGaussianLog10[gaussianIndex] += MathUtils.normalDistributionLog10(gaussian.mu[iii], gaussian.sigma.get(iii, iii), datum.annotations[iii]);
+ }
Contributor:

It's not clear to me why this special case is needed. If it's negative infinity, then adding the normalDistributionLog10 to it yields negative infinity. You avoid calculating it, of course, but presumably this is an edge case and the optimization doesn't matter.

My other guess is that sigma or mu could be NaNs somehow? If so, why is this connected to zero weight?

ldgauthier (Contributor, Author):

Yeah, sigma and mu are in fact NaNs. I'm not entirely sure where that happens in the code, but I can look.
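
For context on why the guard matters: in IEEE floating point, adding a NaN to negative infinity yields NaN, so a NaN mu or sigma would silently turn the intended log10(0) into NaN rather than leaving it at negative infinity. A minimal illustration (Python, not the GATK code):

```python
import math

p_mixture_log10 = -math.inf   # zero-weight Gaussian in log10 space
log_density = math.nan        # what a normal log-density evaluates to if mu/sigma are NaN

# Without the guard, the intended -inf (i.e., probability zero) becomes NaN:
print(p_mixture_log10 + log_density)   # nan
```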

Contributor:

I think this typically happens when the effective number (i.e., the sum of responsibilities) for the component goes to zero; if so, you can just add an epsilon to the denominators in the updates for the means and covariances to avoid this. Perhaps see my notebook from an ancient MIA primer for some pointers.
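
A minimal sketch of the epsilon idea, assuming responsibility-weighted mean/covariance updates (illustrative numpy, not the GATK implementation; variable names are hypothetical):

```python
import numpy as np

EPSILON = 1e-10  # hypothetical small constant; keeps the denominators away from zero

def update_component(data, resp_k):
    """Responsibility-weighted mean/covariance for one mixture component.

    data: (N, D) array of annotations; resp_k: (N,) responsibilities for component k.
    """
    n_k = resp_k.sum()                                  # effective number of points
    mu_k = resp_k @ data / (n_k + EPSILON)              # guarded mean update
    centered = data - mu_k
    sigma_k = (resp_k[:, None] * centered).T @ centered / (n_k + EPSILON)
    return mu_k, sigma_k
```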

Any plans to update this implementation? See a few of my objections in #2062.

Contributor:

I think I understand what's going on (that bus number is on the move now). When a Gaussian has no data assigned, the M step in MultivariateGaussian.maximizeGaussian ends up with zero matrices for pVarSigma and wishart, and thus the resulting sigma equals the empiricalSigma, which is just the whole-data (unclustered) covariance matrix.

Since one of the annotations is constant and has zero variance, this empiricalSigma is degenerate. Because the clusters with data assigned have non-zero wishart and pVarSigma, only the empty cluster has a problem.

I'm pretty sure that this is wrong, because variational Bayes should regularize the tendency of an empty cluster to take on a degenerate value; that's the role of the prior. We could fix it by reconciling the code with Chapter 10 of Bishop (at the cost of several days of dragon-slaying), or we could just randomly re-initialize sigma for an empty cluster, but I would like to fix the problem at its source so that we avoid creating a pathological Gaussian in the first place.
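
A quick numpy illustration of the degeneracy described above (hypothetical data; any constant annotation column makes the whole-data covariance singular):

```python
import numpy as np

rng = np.random.default_rng(0)
annotations = rng.normal(size=(1000, 3))
annotations[:, 1] = 5.0                        # one constant annotation

empirical_sigma = np.cov(annotations, rowvar=False)
print(np.linalg.det(empirical_sigma))          # ~0: degenerate (singular) covariance
# np.linalg.inv(empirical_sigma) raises LinAlgError ("Singular matrix") here,
# so a Gaussian that falls back to this matrix is pathological.
```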

ldgauthier (Contributor, Author):

I understand the basic idea of the Bishop approach of dealing with the singularity (mentioned at the end of 10.2.1), but I don't have any confidence in being able to work out the new closed-form solution. Maybe I can get some help offline?

ldgauthier (Contributor, Author):

I have a number of reasons to agree that it's just standard Bishop. To fully generalize, I think all that has to happen is to uncouple the actual data from the rest of the VariantDatum object, which I can probably do in a not-unreasonable amount of time; I think I modified that class in the past. I might also be able to fully uncouple the model training from the inference, which I did in a crufty way in the past. And improve on the memory requirements...

Contributor:

I'm not yet convinced that our implementation is faithful to the variational Bayes GMM. Bishop 10.6 factorizes down to a joint distribution on mu and sigma, but it doesn't split any further, for one thing. And even if it did, our GMM has a point estimate of sigma for each cluster, which is not the same as a VB M step that computes a posterior. Finally, I couldn't find anything in Bishop that resembles our empiricalSigma, which is the covariance of all the data.

Note that one might worry that Bishop's joint posterior on mu and sigma -- a Gauss-Wishart distribution -- would be intractable for VQSR because the predictive densities could not be computed, but this is not the case. As Bishop shows, this is just a Student's t.
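
For reference, the predictive density in question is (schematically, in Bishop's notation, eqs. 10.81-10.82) a mixture of Student's t distributions:

$$
p(\hat{x} \mid X) \simeq \frac{1}{\hat{\alpha}} \sum_{k=1}^{K} \alpha_k \,\mathrm{St}\!\left(\hat{x} \mid m_k,\, L_k,\, \nu_k + 1 - D\right),
\qquad
L_k = \frac{(\nu_k + 1 - D)\,\beta_k}{1 + \beta_k}\, W_k,
$$

where $\hat{\alpha} = \sum_k \alpha_k$ and $D$ is the number of annotation dimensions, so the densities stay cheap to evaluate.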

This is not to say that I think we should go after this dragon, but I think it's worth settling whether our implementation is actually a true VB GMM.

@samuelklee (Contributor), Jan 30, 2020:

Don't empiricalMu and empiricalSigma determine the respective priors? From a glance, it looks like they are initialized to the origin and some magic number times the identity matrix, respectively, and then never updated. You might be right about the implementation not being exactly faithful, but I think I concluded that it probably attempted to be.
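
For reference, the Gaussian-Wishart prior being discussed is (Bishop eq. 10.40, his notation):

$$
p(\mu_k, \Lambda_k) = \mathcal{N}\!\left(\mu_k \mid m_0,\, (\beta_0 \Lambda_k)^{-1}\right)\,\mathcal{W}\!\left(\Lambda_k \mid W_0,\, \nu_0\right),
$$

with empiricalMu presumably playing the role of $m_0$ and empiricalSigma (or its inverse) setting $W_0$.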

I think the difficulty of maintaining or improving this code arises because 1) the intended implementation is not clearly documented, and 2) the code is neither documented in terms of, nor easily mapped onto, the corresponding mathematical expressions. This is further compounded by the need for the relatively trivial extraction and refactoring mentioned previously.

Contributor:

The point estimate you mention (i.e., the one computed by evaluateFinalModelParameters) seems to be somewhat arbitrarily chosen, but I think that the updates for the variational parameters calculated in maximizeGaussian do seem to follow 10.58 and 10.60-63.
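
For reference, those variational updates are (Bishop eqs. 10.58 and 10.60-10.63, his notation):

$$
\alpha_k = \alpha_0 + N_k, \qquad
\beta_k = \beta_0 + N_k, \qquad
\nu_k = \nu_0 + N_k,
$$
$$
m_k = \frac{1}{\beta_k}\left(\beta_0 m_0 + N_k \bar{x}_k\right), \qquad
W_k^{-1} = W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}\left(\bar{x}_k - m_0\right)\left(\bar{x}_k - m_0\right)^{\mathsf{T}},
$$

where $N_k = \sum_n r_{nk}$ is the effective number of points, $\bar{x}_k$ the responsibility-weighted mean, and $S_k$ the responsibility-weighted covariance about $\bar{x}_k$.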

Contributor:

Just noting for future reference that I went digging into the original GATK3 code (the commit history over there seems corrupted... not sure if something went wrong at some point, perhaps during a public/protected split?). It looks like empiricalMu and empiricalSigma were indeed originally initialized to the mean and covariance of the data (though the code for the latter was commented out); I'm not sure what the justification for this was. In broadgsa/gatk@3224bbe these were changed to represent the quantities for the Bishop prior, but the variable names were not changed.

@@ -406,6 +406,8 @@ public void testVariantRecalibratorSNPscattered(final String[] params) throws IO
doSNPTest(params, getLargeVQSRTestDataDir() + "/snpTranches.scattered.txt", getLargeVQSRTestDataDir() + "snpRecal.vcf"); //tranches file isn't in the expected/ directory because it's input to GatherTranchesIntegrationTest
}

//One of the 8 positive Gaussians (index 4) has weight zero, which is fine, but the one at index 2 has a covariance
Contributor:

Is there anything in the data or arguments from which we expect a priori that these errors will occur, and at the given indices?

ldgauthier (Contributor, Author):

In the input data AS_MQRankSum is constant, which I note where the argument string is defined, but I could make that clearer. Specifying the index of the Gaussian probably isn't informative, since anything that touches the random number generator has the potential to change it.

Contributor:

In that case a comment here noting that the constancy of AS_MQRankSum causes one Gaussian to have a degenerate covariance would be helpful. I would either not even mention that it's the one with index 2, or be explicit that this is arbitrary (likewise for the zero-weight Gaussian at index 4).

@davidbenjamin self-assigned this on Jan 30, 2020
…one or more Gaussians to zero (hence *MAX* Gaussians), but this causes problems in log10 space
@ldgauthier force-pushed the ldg_debugPositiveModel branch from 67fd85d to f54aafd on February 4, 2020 19:55
@ldgauthier (Contributor, Author):

@davidbenjamin removing the final "M-ish" step in evaluateFinalModelParameters like we talked about made a huge difference in the VQSLODs in the test data. I'm going to continue looking into it, but I really want to get this fix for production into this week's release.

@davidbenjamin (Contributor):

@ldgauthier I can live with that. 👍 to the quick fix.

@ldgauthier (Contributor, Author):

@davidbenjamin can I get a review +1 so I can get this into the release?

@davidbenjamin (Contributor):

Yes.

@ldgauthier merged commit d3cd0cc into master on Feb 7, 2020
@ldgauthier deleted the ldg_debugPositiveModel branch on February 7, 2020 15:33
@samuelklee (Contributor):

@ldgauthier @davidbenjamin any decision about what further improvements are worth pursuing? Looking for some quick and easy things to look into now that I'm back.

In any case, are there any baseline runtime/memory estimates? How about typical numbers of variants and annotation dimensions that the tool is run with? Just curious about a quick comparison with sklearn, etc.

@ldgauthier (Contributor, Author):

It's on my list. Pretty near the bottom, but it's there.

Runtime is probably getting up near an hour for big jobs. The memory requirements are horrific, because we load all the variants into memory and then we don't even use them all! For the biggest cohorts we use 104GB. I wish I was joking. If sklearn can minibatch GMMs then that would be amazing. We use a maximum of 2.5M variants for training and number of annotations/dimensions is O(10). The smallest exome cohort would train with about 80,000 variants with about 3GB of memory.

That being said, this definitely isn't the biggest cost contributor for joint calling, and hopefully all the sporadic failures have been hammered out.

@samuelklee (Contributor):

Did a quick test with sklearn's BayesianGaussianMixture, fitting 8 components to 2.5M 10D points generated from 4 isotropic blobs. On a Google Colab instance (which I believe is an n1-highmem-2), 150 iterations (which I think is the current maximum) completed in 14 minutes, with %memit reporting a memory peak of ~1.5GB.

Note that convergence within the default tolerance isn't actually reached in 150 iterations for this toy data (as usual, it takes a while for the weights of unused components to shrink to zero). In any case, we'd have to compare against the number of iterations currently required to converge with the real data (and perhaps also check that the convergence criteria match up) to get a better idea of real runtime. Various tweaks to priors or other runtime options (such as k-means vs. random initialization) could also affect convergence speed.

Minibatching isn't built in, but I think it should be pretty trivial to hack together something with the warm_start option; we could probably just do a warm start with a subset of the data. See also scikit-learn/scikit-learn#9334.
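
A sketch of the kind of toy benchmark described above; the exact settings used are not given here, so these parameters are assumptions:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import BayesianGaussianMixture

# 2.5M 10-dimensional points drawn from 4 isotropic blobs
X, _ = make_blobs(n_samples=2_500_000, n_features=10, centers=4, random_state=0)

gmm = BayesianGaussianMixture(
    n_components=8,          # more components than true clusters, as in the test above
    covariance_type="full",
    max_iter=150,
    random_state=0,
)
gmm.fit(X)
print(gmm.weights_)          # weights of unused components shrink toward zero
```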

@samuelklee (Contributor), Mar 4, 2020:

Doing a warm start with 1% of the toy data followed by a full fit takes <30s from start to finish and convergence is reached on the full fit in less than 10 iterations. We'd probably have to do some testing on real data to make sure we don't lose any important clusters with low weight using such a strategy (or some variation thereof), but in any case I think we can come in well under 1hr and 104GB...
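
One possible version of the subsample-then-warm-start strategy described above (an illustrative sketch; the 1% fraction and other settings are assumptions):

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import BayesianGaussianMixture

X, _ = make_blobs(n_samples=2_500_000, n_features=10, centers=4, random_state=0)

rng = np.random.default_rng(0)
subset = X[rng.choice(len(X), size=len(X) // 100, replace=False)]  # ~1% of the data

gmm = BayesianGaussianMixture(
    n_components=8, covariance_type="full", max_iter=150,
    warm_start=True, random_state=0,  # warm_start makes the next fit() resume from the last solution
)
gmm.fit(subset)   # cheap initial fit on the subsample
gmm.fit(X)        # warm-started full fit; should need far fewer iterations
```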

@davidbenjamin (Contributor):

That all sounds really good.
