Skip to content

Commit

Permalink
Optimizations for GenotypeGVCFs + porting synchronized caches from gatk3 (#1957)
Browse files Browse the repository at this point in the history
  • Loading branch information
akiezun authored and lbergelson committed Jul 6, 2016
1 parent 66081a4 commit 575984f
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ private int alleleHeapToIndex() {
*
* @return never {@code null}.
*/
public int[] genotypeIndexMap(final int[] oldToNewAlleleIndexMap) {
public int[] genotypeIndexMap(final int[] oldToNewAlleleIndexMap, final GenotypeLikelihoodCalculators calculators) {
if (oldToNewAlleleIndexMap == null) {
throw new IllegalArgumentException("the input encoding array cannot be null");
}
Expand All @@ -547,7 +547,7 @@ public int[] genotypeIndexMap(final int[] oldToNewAlleleIndexMap) {
+ resultAlleleCount + " alleles ");
}
final int resultLength = resultAlleleCount == alleleCount
? genotypeCount : new GenotypeLikelihoodCalculators().genotypeCount(ploidy,resultAlleleCount);
? genotypeCount : calculators.genotypeCount(ploidy,resultAlleleCount);

final int[] result = new int[resultLength];
final int[] sortedAlleleCounts = new int[Math.max(ploidy, alleleCount) << 1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.Arrays;

Expand All @@ -15,6 +17,8 @@
*/
public final class GenotypeLikelihoodCalculators {

private static final Logger logger = LogManager.getLogger(GenotypeLikelihoodCalculators.class);

/**
* The current maximum ploidy supported by the tables.
* <p>
Expand Down Expand Up @@ -62,6 +66,10 @@ public final class GenotypeLikelihoodCalculators {
private GenotypeAlleleCounts[][] genotypeTableByPloidy =
buildGenotypeAlleleCountsTable(maximumPloidy,maximumAllele,alleleFirstGenotypeOffsetByPloidy);

/**
 * Creates a new {@code GenotypeLikelihoodCalculators} instance.
 * <p>
 * The body is intentionally empty: instance fields such as
 * {@code genotypeTableByPloidy} are populated by their declaration initializers.
 * </p>
 */
public GenotypeLikelihoodCalculators(){

}

/**
* Build the table with the genotype offsets based on ploidy and the maximum allele index with representation
* in the genotype.
Expand Down Expand Up @@ -283,6 +291,8 @@ private void ensureCapacity(final int requestedMaximumAllele, final int requeste
final int newMaximumPloidy = Math.max(maximumPloidy, requestedMaximumPloidy);
final int newMaximumAllele = Math.max(maximumAllele, requestedMaximumAllele);

logger.debug("Expanding capacity ploidy:" + maximumPloidy + "->" + newMaximumPloidy + " allele:" + maximumAllele +"->" + newMaximumAllele );

// Update tables first.
alleleFirstGenotypeOffsetByPloidy = buildAlleleFirstGenotypeOffsetTable(newMaximumPloidy,newMaximumAllele);
genotypeTableByPloidy = buildGenotypeAlleleCountsTable(newMaximumPloidy,newMaximumAllele,alleleFirstGenotypeOffsetByPloidy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public final class IndependentSampleGenotypesModel {
private final int cacheAlleleCountCapacity;
private final int cachePloidyCapacity;
private GenotypeLikelihoodCalculator[][] likelihoodCalculators;
private final GenotypeLikelihoodCalculators calculators;

/**
 * Creates a model using the default cache capacities, by delegating to the
 * two-argument constructor with {@code DEFAULT_CACHE_PLOIDY_CAPACITY} and
 * {@code DEFAULT_CACHE_ALLELE_CAPACITY}.
 */
public IndependentSampleGenotypesModel() { this(DEFAULT_CACHE_PLOIDY_CAPACITY, DEFAULT_CACHE_ALLELE_CAPACITY); }

Expand All @@ -31,6 +32,7 @@ public IndependentSampleGenotypesModel(final int calculatorCachePloidyCapacity,
cachePloidyCapacity = calculatorCachePloidyCapacity;
cacheAlleleCountCapacity = calculatorCacheAlleleCapacity;
likelihoodCalculators = new GenotypeLikelihoodCalculator[calculatorCachePloidyCapacity][calculatorCacheAlleleCapacity];
calculators = new GenotypeLikelihoodCalculators();
}

public <A extends Allele> GenotypingLikelihoods<A> calculateLikelihoods(final AlleleList<A> genotypingAlleles, final GenotypingData<A> data) {
Expand Down Expand Up @@ -62,13 +64,13 @@ public <A extends Allele> GenotypingLikelihoods<A> calculateLikelihoods(final Al

private GenotypeLikelihoodCalculator getLikelihoodsCalculator(final int samplePloidy, final int alleleCount) {
if (samplePloidy >= cachePloidyCapacity || alleleCount >= cacheAlleleCountCapacity) {
return new GenotypeLikelihoodCalculators().getInstance(samplePloidy, alleleCount);
return calculators.getInstance(samplePloidy, alleleCount);
}
final GenotypeLikelihoodCalculator result = likelihoodCalculators[samplePloidy][alleleCount];
if (result != null) {
return result;
} else {
final GenotypeLikelihoodCalculator newOne = new GenotypeLikelihoodCalculators().getInstance(samplePloidy, alleleCount);
final GenotypeLikelihoodCalculator newOne = calculators.getInstance(samplePloidy, alleleCount);
likelihoodCalculators[samplePloidy][alleleCount] = newOne;
return newOne;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public final class GeneralPloidyExactAFCalculator extends ExactAFCalculator {
static final int MAX_LENGTH_FOR_POOL_PL_LOGGING = 100; // if PL vectors longer than this # of elements, don't log them

private static final boolean VERBOSE = false;
private final GenotypeLikelihoodCalculators calculators = new GenotypeLikelihoodCalculators();

@Override
protected GenotypesContext reduceScopeGenotypes(final VariantContext vc, final int defaultPloidy, final List<Allele> allelesToUse) {
Expand Down Expand Up @@ -462,15 +463,15 @@ public GenotypesContext subsetAlleles(final VariantContext vc,
* @param allelesToUse the list of alleles to choose from (corresponding to the PLs)
* @param numChromosomes Number of chromosomes per pool
*/
private static void assignGenotype(final GenotypeBuilder gb,
private void assignGenotype(final GenotypeBuilder gb,
final double[] newLikelihoods,
final List<Allele> allelesToUse,
final int numChromosomes) {
final int numNewAltAlleles = allelesToUse.size() - 1;

// find the genotype with maximum likelihoods
final int PLindex = numNewAltAlleles == 0 ? 0 : MathUtils.maxElementIndex(newLikelihoods);
final GenotypeLikelihoodCalculator calculator = new GenotypeLikelihoodCalculators().getInstance(numChromosomes, allelesToUse.size());
final GenotypeLikelihoodCalculator calculator = calculators.getInstance(numChromosomes, allelesToUse.size());
final GenotypeAlleleCounts alleleCounts = calculator.genotypeAlleleCountsAt(PLindex);

gb.alleles(alleleCounts.asAlleleList(allelesToUse));
Expand Down Expand Up @@ -706,9 +707,8 @@ private static int getLinearIndex(final int[] vectorIdx, final int numAlleles, f
* @param PLindex Index to query
* @return Allele count conformation, according to iteration order from SumIterator
*/
private static int[] getAlleleCountFromPLIndex(final int nAlleles, final int numChromosomes, final int PLindex) {

final GenotypeLikelihoodCalculator calculator = new GenotypeLikelihoodCalculators().getInstance(numChromosomes, nAlleles);
private int[] getAlleleCountFromPLIndex(final int nAlleles, final int numChromosomes, final int PLindex) {
final GenotypeLikelihoodCalculator calculator = calculators.getInstance(numChromosomes, nAlleles);
final GenotypeAlleleCounts alleleCounts = calculator.genotypeAlleleCountsAt(PLindex);
return alleleCounts.alleleCountsByIndex(nAlleles - 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ public final class IndependentAllelesDiploidExactAFCalculator extends DiploidExa
* The AFCalc model we are using to do the bi-allelic computation
*/
private final AFCalculator biAlleleExactModel;
private final GenotypeLikelihoodCalculators calculators;

/**
 * Package-private constructor.
 * <p>
 * Instantiates the bi-allelic exact model as a {@code ReferenceDiploidExactAFCalculator}
 * and creates one {@code GenotypeLikelihoodCalculators} instance that is kept in the
 * {@code calculators} field, so instance methods (e.g. {@code makeCombinedAltAllelesVariantContext})
 * can use it rather than constructing a new one on every call.
 * </p>
 */
IndependentAllelesDiploidExactAFCalculator() {
// delegate model used for the bi-allelic computation
biAlleleExactModel = new ReferenceDiploidExactAFCalculator();
// per-instance calculators object shared by this object's methods
calculators = new GenotypeLikelihoodCalculators();
}

@Override
Expand Down Expand Up @@ -113,7 +115,7 @@ private AFCalculationResult combineAltAlleleIndependentExact(final VariantContex
return biAlleleExactModel.getLog10PNonRef(combinedAltAllelesVariantContext, defaultPloidy, vc.getNAlleles() - 1, log10AlleleFrequencyPriors);
}

private static VariantContext makeCombinedAltAllelesVariantContext(final VariantContext vc) {
private VariantContext makeCombinedAltAllelesVariantContext(final VariantContext vc) {
final int nAltAlleles = vc.getNAlleles() - 1;

if ( nAltAlleles == 1 ) {
Expand All @@ -122,7 +124,7 @@ private static VariantContext makeCombinedAltAllelesVariantContext(final Variant
final VariantContextBuilder vcb = new VariantContextBuilder(vc);
final Allele reference = vcb.getAlleles().get(0);
vcb.alleles(Arrays.asList(reference, GATKVCFConstants.NON_REF_SYMBOLIC_ALLELE));
final int genotypeCount = new GenotypeLikelihoodCalculators().genotypeCount(2, vc.getNAlleles());
final int genotypeCount = calculators.genotypeCount(2, vc.getNAlleles());
final double[] hetLikelihoods = new double[vc.getNAlleles() - 1];
final double[] homAltLikelihoods = new double[genotypeCount - hetLikelihoods.length - 1];
final double[] newLikelihoods = new double[3];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,20 @@ private double calculateAlleleCountConformation(final ExactACset set,

private static void computeLofK(final ExactACset set,
final List<double[]> genotypeLikelihoods,
final double[] log10AlleleFrequencyPriors, final StateTracker stateTracker) {
final double[] log10AlleleFrequencyPriors,
final StateTracker stateTracker) {

set.getLog10Likelihoods()[0] = 0.0; // the zero case
final double[] setLog10Likelihoods = set.getLog10Likelihoods();
setLog10Likelihoods[0] = 0.0; // the zero case
final int totalK = set.getACsum();

// special case for k = 0 over all k
if ( totalK == 0 ) {
for ( int j = 1; j < set.getLog10Likelihoods().length; j++ ) {
set.getLog10Likelihoods()[j] = set.getLog10Likelihoods()[j - 1] + genotypeLikelihoods.get(j)[HOM_REF_INDEX];
for (int j = 1, n = setLog10Likelihoods.length; j < n; j++ ) {
setLog10Likelihoods[j] = setLog10Likelihoods[j - 1] + genotypeLikelihoods.get(j)[HOM_REF_INDEX];
}

final double log10Lof0 = set.getLog10Likelihoods()[set.getLog10Likelihoods().length-1];
final double log10Lof0 = setLog10Likelihoods[setLog10Likelihoods.length-1];
stateTracker.setLog10LikelihoodOfAFzero(log10Lof0);
stateTracker.setLog10PosteriorOfAFzero(log10Lof0 + log10AlleleFrequencyPriors[0]);
return;
Expand All @@ -146,18 +148,18 @@ private static void computeLofK(final ExactACset set,
// if we got here, then k > 0 for at least one k.
// the non-AA possible conformations were already dealt with by pushes from dependent sets;
// now deal with the AA case (which depends on previous cells in this column) and then update the L(j,k) value
for ( int j = 1; j < set.getLog10Likelihoods().length; j++ ) {
for (int j = 1, n = setLog10Likelihoods.length; j < n; j++ ) {
if ( totalK < 2*j-1 ) {
final double[] gl = genotypeLikelihoods.get(j);
final double conformationValue = MathUtils.log10(2*j-totalK) + MathUtils.log10(2*j-totalK-1) + set.getLog10Likelihoods()[j-1] + gl[HOM_REF_INDEX];
set.getLog10Likelihoods()[j] = MathUtils.approximateLog10SumLog10(set.getLog10Likelihoods()[j], conformationValue);
final double conformationValue = MathUtils.log10(2*j-totalK) + MathUtils.log10(2*j-totalK-1) + setLog10Likelihoods[j-1] + gl[HOM_REF_INDEX];
setLog10Likelihoods[j] = MathUtils.approximateLog10SumLog10(setLog10Likelihoods[j], conformationValue);
}

final double logDenominator = MathUtils.log10(2*j) + MathUtils.log10(2*j-1);
set.getLog10Likelihoods()[j] = set.getLog10Likelihoods()[j] - logDenominator;
setLog10Likelihoods[j] = setLog10Likelihoods[j] - logDenominator;
}

double log10LofK = set.getLog10Likelihoods()[set.getLog10Likelihoods().length-1];
double log10LofK = setLog10Likelihoods[setLog10Likelihoods.length-1];

// update the MLE if necessary
stateTracker.updateMLEifNeeded(log10LofK, set.getACcounts().getCounts());
Expand Down Expand Up @@ -209,12 +211,16 @@ private static void pushData(final ExactACset targetSet,
final List<double[]> genotypeLikelihoods) {
final int totalK = targetSet.getACsum();

for ( int j = 1; j < targetSet.getLog10Likelihoods().length; j++ ) {
if ( totalK <= 2*j ) { // skip impossible conformations
final double[] targetSetLog10Likelihoods = targetSet.getLog10Likelihoods();
final double[] dependentSetLog10Likelihoods = dependentSet.getLog10Likelihoods();
final int[] counts = targetSet.getACcounts().getCounts();

for ( int j = 1, n = targetSetLog10Likelihoods.length; j < n; j++ ) {
if (2 * j >= totalK) { // skip impossible conformations
final double[] gl = genotypeLikelihoods.get(j);
final double conformationValue =
determineCoefficient(PLsetIndex, j, targetSet.getACcounts().getCounts(), totalK) + dependentSet.getLog10Likelihoods()[j-1] + gl[PLsetIndex];
targetSet.getLog10Likelihoods()[j] = MathUtils.approximateLog10SumLog10(targetSet.getLog10Likelihoods()[j], conformationValue);
determineCoefficient(PLsetIndex, j, counts, totalK) + dependentSetLog10Likelihoods[j-1] + gl[PLsetIndex];
targetSetLog10Likelihoods[j] = MathUtils.approximateLog10SumLog10(targetSetLog10Likelihoods[j], conformationValue);
}
}
}
Expand Down
Loading

0 comments on commit 575984f

Please sign in to comment.