Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplified genotype likelihood calculation (no change in output) #6351

Merged
merged 1 commit into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package org.broadinstitute.hellbender.tools.copynumber.models;

import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;

import java.util.List;
import java.util.stream.IntStream;

import static org.apache.commons.math3.util.FastMath.sqrt;
import static org.broadinstitute.hellbender.utils.MathUtils.log10Factorial;
import static org.broadinstitute.hellbender.utils.MathUtils.log10ToLog;

/**
* Contains likelihood methods for the allele-fraction model.
Expand Down Expand Up @@ -87,10 +86,7 @@ static double hetLogLikelihood(final AlleleFractionGlobalParameters parameters,
- n * log(majorFraction + minorFraction * lambda0RefMinor);
final double refMinorLogLikelihood = logNotPi + logcRefMinor + Gamma.logGamma(rhoRefMinor) - rhoRefMinor * log(tauRefMinor);

// changing the factorial implementation below may introduce non-negligible numerical differences;
// note https://github.com/broadinstitute/gatk/pull/7652
final double outlierLogLikelihood = logPi + log10ToLog(log10Factorial(a) + log10Factorial(r) - log10Factorial(a + r + 1));

final double outlierLogLikelihood = logPi - Math.log(a + r + 1) - CombinatoricsUtils.binomialCoefficientLog(a+r,a);
return NaturalLogUtils.logSumExp(altMinorLogLikelihood, refMinorLogLikelihood, outlierLogLikelihood);
}

Expand Down Expand Up @@ -165,6 +161,6 @@ private static double biasPosteriorEffectiveBeta(final double lambda0, final dou
}

private static double log(final double x) {
return FastMath.log(Math.max(EPSILON, x));
return Math.log(Math.max(EPSILON, x));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ public CopyRatioState.OutlierIndicators sample(final RandomGenerator rng,
final CopyRatioSegmentedData data) {
logger.debug("Sampling outlier indicators...");
final double outlierUnnormalizedLogProbability =
FastMath.log(state.outlierProbability()) + outlierUniformLogLikelihood;
Math.log(state.outlierProbability()) + outlierUniformLogLikelihood;
// final double notOutlierUnnormalizedLogProbabilityPrefactor =
// FastMath.log(1. - state.outlierProbability()) - 0.5 * FastMath.log(2 * Math.PI * state.variance());
// Math.log(1. - state.outlierProbability()) - 0.5 * Math.log(2 * Math.PI * state.variance());
final double notOutlierUnnormalizedLogProbabilityPrefactor =
FastMath.log((1. - state.outlierProbability()) / FastMath.sqrt(2 * Math.PI * state.variance()));
Math.log((1. - state.outlierProbability()) / FastMath.sqrt(2 * Math.PI * state.variance()));
final List<Boolean> indicators = new ArrayList<>(data.getNumPoints());
for (int segmentIndex = 0; segmentIndex < data.getNumSegments(); segmentIndex++) {
final List<CopyRatioSegmentedData.IndexedCopyRatio> indexedCopyRatiosInSegment = data.getIndexedCopyRatiosInSegment(segmentIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,6 @@ private static double calculateHomozygousLogRatio(final AllelicCount allelicCoun
final double betaOneMinusError = Beta.regularizedBeta(1 - genotypingBaseErrorRate, r + 1, n - r + 1);
final double betaHom = betaError + betaAll - betaOneMinusError;
final double betaHet = betaOneMinusError - betaError;
return FastMath.log(betaHom) - FastMath.log(betaHet);
return Math.log(betaHom) - Math.log(betaHet);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine;
import org.broadinstitute.hellbender.tools.walkers.annotator.allelespecific.AlleleSpecificAnnotationData;
import org.broadinstitute.hellbender.tools.walkers.annotator.allelespecific.ReducibleAnnotationData;
import org.broadinstitute.hellbender.tools.walkers.genotyper.AlleleSubsettingUtils;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAssignmentMethod;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculators;
import org.broadinstitute.hellbender.tools.walkers.genotyper.*;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2FilteringEngine;
import org.broadinstitute.hellbender.utils.GenotypeUtils;
import org.broadinstitute.hellbender.utils.Utils;
Expand All @@ -34,7 +32,6 @@
@SuppressWarnings({"rawtypes","unchecked"}) //TODO fix uses of untyped Comparable.
public final class ReferenceConfidenceVariantContextMerger {

private static final GenotypeLikelihoodCalculators calculators = new GenotypeLikelihoodCalculators();
private static VCFHeader vcfInputHeader = null;
protected final VariantAnnotatorEngine annotatorEngine;
private final boolean doSomaticMerge;
Expand Down Expand Up @@ -571,7 +568,6 @@ private GenotypesContext mergeRefConfidenceGenotypes(final VariantContext vc,
// the map is different depending on the ploidy, so in order to keep this method flexible (mixed ploidies)
// we need to get a map done (lazily inside the loop) for each ploidy, up to the maximum possible.
final int[][] genotypeIndexMapsByPloidy = new int[maximumPloidy + 1][];
final int maximumAlleleCount = Math.max(remappedAlleles.size(),targetAlleles.size());

for ( final Genotype g : vc.getGenotypes() ) {
final String name;
Expand All @@ -584,20 +580,17 @@ private GenotypesContext mergeRefConfidenceGenotypes(final VariantContext vc,
final GenotypeBuilder genotypeBuilder = new GenotypeBuilder(g);
if (!doSomaticMerge) {
if (g.hasPL() || g.hasAD()) {
int[] perSampleIndexesOfRelevantAlleles = AlleleSubsettingUtils.getIndexesOfRelevantAllelesForGVCF(remappedAlleles, targetAlleles, vc.getStart(), g, false);
int[] perSampleIndexesOfRelevantAlleles = AlleleSubsettingUtils.getIndexesOfRelevantAllelesForGVCF(remappedAlleles, targetAlleles, vc.getStart(), g, false);
if (g.hasPL()) {
// lazy initialization of the genotype index map by ploidy.
final int[] genotypeIndexMapByPloidy = genotypeIndexMapsByPloidy[ploidy] == null
? calculators.getInstance(ploidy, maximumAlleleCount).genotypeIndexMap(perSampleIndexesOfRelevantAlleles, calculators) //probably horribly slow
? GenotypeIndexCalculator.newToOldGenotypeMap(ploidy, perSampleIndexesOfRelevantAlleles) //probably horribly slow
: genotypeIndexMapsByPloidy[ploidy];
final int[] PLs = generatePL(g, genotypeIndexMapByPloidy);
genotypeBuilder.PL(PLs);
genotypeBuilder.PL(generatePL(g, genotypeIndexMapByPloidy));
}
if (g.hasAD()) {
final int[] AD = AlleleSubsettingUtils.generateAD(g.getAD(), perSampleIndexesOfRelevantAlleles);
genotypeBuilder.AD(AD);
genotypeBuilder.AD(AlleleSubsettingUtils.generateAD(g.getAD(), perSampleIndexesOfRelevantAlleles));
}
// clean up low confidence hom refs for better annotations later
//clean up low confidence hom refs for better annotations later
} else if (GenotypeGVCFsEngine.excludeFromAnnotations(g)) {
genotypeBuilder.alleles(Collections.nCopies(ploidy, Allele.NO_CALL));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ private static double probability(final PileupSummary site, final double contami
}

private static double segmentLogLikelihood(final List<PileupSummary> segment, final double contamination, final double errorRate, final double minorAlleleFraction) {
return segment.stream().mapToDouble(site -> FastMath.log(MathUtils.sum(genotypeLikelihoods(site, contamination, errorRate, minorAlleleFraction)))).sum();
return segment.stream().mapToDouble(site -> Math.log(MathUtils.sum(genotypeLikelihoods(site, contamination, errorRate, minorAlleleFraction)))).sum();
}

private static double modelLogLikelihood(final List<List<PileupSummary>> segments, final double contamination, final double errorRate, final List<Double> mafs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import htsjdk.variant.variantcontext.*;
import htsjdk.variant.vcf.*;
import htsjdk.variant.vcf.VCFConstants;
import htsjdk.variant.vcf.VCFFormatHeaderLine;
import htsjdk.variant.vcf.VCFHeader;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.annotator.AnnotationUtils;
import org.broadinstitute.hellbender.utils.genotyper.GenotypePriorCalculator;
import org.broadinstitute.hellbender.tools.walkers.ReferenceConfidenceVariantContextMerger;
import org.broadinstitute.hellbender.tools.walkers.annotator.AnnotationUtils;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.collections.Permutation;
import org.broadinstitute.hellbender.utils.genotyper.GenotypePriorCalculator;
import org.broadinstitute.hellbender.utils.genotyper.IndexedAlleleList;
import org.broadinstitute.hellbender.utils.logging.OneShotLogger;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
Expand All @@ -37,8 +39,6 @@ private AlleleSubsettingUtils() {} // prevent instantiation

private static final OneShotLogger attributesRemovedOneShotLogger = new OneShotLogger(AlleleSubsettingUtils.class);

private static final GenotypeLikelihoodCalculators GL_CALCS = new GenotypeLikelihoodCalculators();

public static GenotypesContext subsetAlleles(final GenotypesContext originalGs, final int defaultPloidy,
final List<Allele> originalAlleles,
final List<Allele> allelesToKeep,
Expand All @@ -47,6 +47,7 @@ public static GenotypesContext subsetAlleles(final GenotypesContext originalGs,
//TODO: if other usages of this method should update or remove A,R, or G length annotations then header parsing is necessary and the method below should be used
return subsetAlleles(originalGs, defaultPloidy, originalAlleles, allelesToKeep, gpc, assignmentMethod, Collections.emptyList());
}

/**
* Create the new GenotypesContext with the subsetted PLs and ADs
*
Expand Down Expand Up @@ -399,12 +400,10 @@ static double[] calculateLikelihoodSums(final VariantContext vc, final int defau
final double GLDiffBetweenRefAndBestVariantGenotype = Math.abs(glsVector[indexOfMostLikelyVariantGenotype] - glsVector[PL_INDEX_OF_HOM_REF]);
final int ploidy = genotype.getPloidy() > 0 ? genotype.getPloidy() : defaultPloidy;

final int[] alleleCounts = new GenotypeLikelihoodCalculators()
.getInstance(ploidy, vc.getNAlleles()).genotypeAlleleCountsAt(indexOfMostLikelyVariantGenotype)
.alleleCountsByIndex(vc.getNAlleles() - 1);
final GenotypeAlleleCounts mostLikelyGenotypeAlleleCounts = GenotypesCache.get(ploidy, indexOfMostLikelyVariantGenotype);

for (int allele = 1; allele < alleleCounts.length; allele++) {
if (alleleCounts[allele] > 0) {
for (int allele = 1; allele < vc.getNAlleles(); allele++) {
if (mostLikelyGenotypeAlleleCounts.containsAllele(allele)) {
likelihoodSums[allele] += GLDiffBetweenRefAndBestVariantGenotype;
}
}
Expand All @@ -428,10 +427,7 @@ public static int[] subsettedPLIndices(final int ploidy, final List<Allele> orig
final int[] result = new int[GenotypeLikelihoods.numLikelihoods(newAlleles.size(), ploidy)];
final Permutation<Allele> allelePermutation = new IndexedAlleleList<>(originalAlleles).permutation(new IndexedAlleleList<>(newAlleles));

final GenotypeLikelihoodCalculator glCalc = GL_CALCS.getInstance(ploidy, originalAlleles.size());
for (int oldPLIndex = 0; oldPLIndex < glCalc.genotypeCount(); oldPLIndex++) {
final GenotypeAlleleCounts oldAlleleCounts = glCalc.genotypeAlleleCountsAt(oldPLIndex);

for (final GenotypeAlleleCounts oldAlleleCounts : GenotypeAlleleCounts.iterable(ploidy, originalAlleles.size())) {
final boolean containsOnlyNewAlleles = IntStream.range(0, oldAlleleCounts.distinctAlleleCount())
.map(oldAlleleCounts::alleleIndexAt).allMatch(allelePermutation::isKept);

Expand All @@ -441,8 +437,8 @@ public static int[] subsettedPLIndices(final int ploidy, final List<Allele> orig
final int[] newAlleleCounts = IntStream.range(0, newAlleles.size()).flatMap(newAlleleIndex ->
IntStream.of(newAlleleIndex, oldAlleleCounts.alleleCountFor(allelePermutation.fromIndex(newAlleleIndex)))).toArray();

final int newPLIndex = glCalc.alleleCountsToIndex(newAlleleCounts);
result[newPLIndex] = oldPLIndex;
final int newPLIndex = GenotypeIndexCalculator.alleleCountsToIndex(newAlleleCounts);
result[newPLIndex] = oldAlleleCounts.index();
}
}
return result;
Expand Down Expand Up @@ -492,39 +488,6 @@ public static int[] getIndexesOfRelevantAllelesForGVCF(final List<Allele> remapp
return indexMapping;
}

public static int[] getIndexesOfRelevantAlleles(final List<Allele> remappedAlleles, final List<Allele> targetAlleles, final int position, final Genotype g) {
Utils.nonEmpty(remappedAlleles);
Utils.nonEmpty(targetAlleles);

final int[] indexMapping = new int[targetAlleles.size()];

// the reference likelihoods should always map to each other (even if the alleles don't)
indexMapping[0] = 0;

for ( int i = 1; i < targetAlleles.size(); i++ ) {
// if there's more than 1 spanning deletion (*) allele then we need to use the best one
if (targetAlleles.get(i) == Allele.SPAN_DEL && g.hasPL()) {
final int occurrences = Collections.frequency(remappedAlleles, Allele.SPAN_DEL);
if (occurrences > 1) {
final int indexOfBestDel = indexOfBestDel(remappedAlleles, g.getPL(), g.getPloidy());
if (indexOfBestDel == -1) {
throw new IllegalArgumentException("At position " + position + " targetAlleles contains a spanning deletion, but remappedAlleles does not.");
}
indexMapping[i] = indexOfBestDel;
continue;
}
}

final int indexOfRemappedAllele = remappedAlleles.indexOf(targetAlleles.get(i));
if (indexOfRemappedAllele == -1) {
throw new IllegalArgumentException("At position " + position + " targetAlleles contains a " + targetAlleles.get(i) + " allele, but remappedAlleles does not.");
}
indexMapping[i] = indexOfRemappedAllele;
}

return indexMapping;
}

/**
* Returns the index of the best spanning deletion allele based on AD counts
*
Expand All @@ -539,7 +502,8 @@ private static int indexOfBestDel(final List<Allele> alleles, final int[] PLs, f

for ( int i = 0; i < alleles.size(); i++ ) {
if ( alleles.get(i) == Allele.SPAN_DEL ) {
final int homAltIndex = findHomIndex(GL_CALCS.getInstance(ploidy, alleles.size()), i, ploidy);
//In the canonical order, the homozygous genotype of the ith allele is immediately followed by the first genotype containing the (i+1)th allele.
final int homAltIndex = (int) GenotypeIndexCalculator.indexOfFirstGenotypeWithAllele(ploidy, i +1) - 1;
final int PL = PLs[homAltIndex];
if ( PL < bestPL ) {
bestIndex = i;
Expand All @@ -551,25 +515,6 @@ private static int indexOfBestDel(final List<Allele> alleles, final int[] PLs, f
return bestIndex;
}

/** //TODO simplify these methods
* Returns the index of the PL that represents the homozygous genotype of the given i'th allele
*
* @param i the index of the allele with the list of alleles
* @param ploidy the ploidy of the sample
* @return the hom index
*/
private static int findHomIndex(final GenotypeLikelihoodCalculator calculator, final int i, final int ploidy) {
// some quick optimizations for the common case
if ( ploidy == 2 )
return GenotypeLikelihoods.calculatePLindex(i, i);
if ( ploidy == 1 )
return i;

final int[] alleleIndexes = new int[ploidy];
Arrays.fill(alleleIndexes, i);
return calculator.allelesToIndex(alleleIndexes);
}

/**
* Generates a new AD array by adding zeros for missing alleles given the set of indexes of the Genotype's current
* alleles from the original AD.
Expand Down
Loading