Sped up Mutect2 reference confidence model with fast likelihoods model
davidbenjamin committed Feb 18, 2020
1 parent 5887df8 commit e09f0d7
Showing 5 changed files with 86 additions and 73 deletions.
Mutect2Engine.java
@@ -1,6 +1,5 @@
package org.broadinstitute.hellbender.tools.walkers.mutect;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.Genotype;
@@ -10,9 +9,9 @@
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import htsjdk.variant.vcf.VCFStandardHeaderLines;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.mutable.MutableLong;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -27,6 +26,7 @@
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.*;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.readthreading.ReadThreadingAssembler;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.FilterMutectCalls;
import org.broadinstitute.hellbender.tools.walkers.readorientation.BetaDistributionShape;
import org.broadinstitute.hellbender.tools.walkers.readorientation.F1R2CountsCollector;
import org.broadinstitute.hellbender.transformers.PalindromeArtifactClipReadTransformer;
import org.broadinstitute.hellbender.transformers.ReadTransformer;
@@ -142,7 +142,7 @@ public Mutect2Engine(final M2ArgumentCollection MTAC, AssemblyRegionArgumentColl
genotypingEngine = new SomaticGenotypingEngine(MTAC, normalSamples, annotationEngine);
haplotypeBAMWriter = AssemblyBasedCallerUtils.createBamWriter(MTAC, createBamOutIndex, createBamOutMD5, header);
trimmer = new AssemblyRegionTrimmer(assemblyRegionArgs, header.getSequenceDictionary());
referenceConfidenceModel = new SomaticReferenceConfidenceModel(samplesList, header, 0, genotypingEngine); //TODO: do something classier with the indel size arg
referenceConfidenceModel = new SomaticReferenceConfidenceModel(samplesList, header, 0, MTAC.minAF); //TODO: do something classier with the indel size arg
final List<String> tumorSamples = ReadUtils.getSamplesFromHeader(header).stream().filter(this::isTumorSample).collect(Collectors.toList());
f1R2CountsCollector = MTAC.f1r2TarGz == null ? Optional.empty() : Optional.of(new F1R2CountsCollector(MTAC.f1r2Args, header, MTAC.f1r2TarGz, tumorSamples));
}
@@ -475,20 +475,28 @@ private static List<Byte> altQuals(final ReadPileup pileup, final byte refBase,
return result;
}

private static double logLikelihoodRatio(final int refCount, final List<Byte> altQuals) {
public static double logLikelihoodRatio(final int refCount, final List<Byte> altQuals) {
return logLikelihoodRatio(refCount, altQuals, 1);
}

// this implements the isActive() algorithm described in docs/mutect/mutect.pdf
// the multiplicative factor is for the special case where we pass a singleton list
// of alt quals and want to duplicate that alt qual over multiple reads
@VisibleForTesting
static double logLikelihoodRatio(final int nRef, final List<Byte> altQuals, final int repeatFactor) {


/**
* This implements the isActive() algorithm described in docs/mutect/mutect.pdf.
* The multiplicative factor is for the special case where we pass a singleton list
* of alt quals and want to duplicate that alt qual over multiple reads.
* @param nRef ref read count
* @param altQuals Phred-scaled qualities of alt-supporting reads
* @param repeatFactor number of times each alt qual is duplicated
* @param afPrior optional Beta prior on the alt allele fraction; if empty, a flat Beta(1,1) prior is used
* @return the natural-log likelihood ratio of the variant hypothesis versus the hypothesis that all alt reads are errors
*/
public static double logLikelihoodRatio(final int nRef, final List<Byte> altQuals, final int repeatFactor, final Optional<BetaDistributionShape> afPrior) {
final int nAlt = repeatFactor * altQuals.size();
final int n = nRef + nAlt;

final double fTildeRatio = FastMath.exp(MathUtils.digamma(nRef + 1) - MathUtils.digamma(nAlt + 1));
final double betaEntropy = MathUtils.log10ToLog(-MathUtils.log10Factorial(n+1) + MathUtils.log10Factorial(nAlt) + MathUtils.log10Factorial(nRef));


double readSum = 0;
for (final byte qual : altQuals) {
@@ -499,8 +507,21 @@ static double logLikelihoodRatio(final int nRef, final List<Byte> altQuals, fina
readSum += zBarAlt * (logOneMinusEpsilon - logEpsilon) + MathUtils.fastBernoulliEntropy(zBarAlt);
}

final double betaEntropy;
if (afPrior.isPresent()) {
final double alpha = afPrior.get().getAlpha();
final double beta = afPrior.get().getBeta();
betaEntropy = Gamma.logGamma(alpha + beta) - Gamma.logGamma(alpha) - Gamma.logGamma(beta)
- Gamma.logGamma(alpha + beta + n) + Gamma.logGamma(alpha + nAlt) + Gamma.logGamma(beta + nRef);
} else {
betaEntropy = MathUtils.log10ToLog(-MathUtils.log10Factorial(n + 1) + MathUtils.log10Factorial(nAlt) + MathUtils.log10Factorial(nRef));
}
return betaEntropy + readSum * repeatFactor;
}

// the default case of a flat Beta(1,1) prior on allele fraction
public static double logLikelihoodRatio(final int nRef, final List<Byte> altQuals, final int repeatFactor) {
return logLikelihoodRatio(nRef, altQuals, repeatFactor, Optional.empty());
}
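
A quick way to see that the new afPrior branch generalizes the old behavior: for a flat Beta(1, 1) prior, the Gamma-function expression collapses to the factorial expression in the default branch, since Gamma.logGamma(k + 1) = log(k!). A minimal standalone check (hypothetical, not part of this commit; assumes only commons-math3 on the classpath):

    import org.apache.commons.math3.special.Gamma;

    public class BetaEntropyCheck {
        public static void main(String[] args) {
            final int nRef = 30, nAlt = 5, n = nRef + nAlt;   // invented counts
            final double alpha = 1, beta = 1;                 // flat Beta(1,1) prior
            final double gammaForm = Gamma.logGamma(alpha + beta) - Gamma.logGamma(alpha) - Gamma.logGamma(beta)
                    - Gamma.logGamma(alpha + beta + n) + Gamma.logGamma(alpha + nAlt) + Gamma.logGamma(beta + nRef);
            // log[nAlt! * nRef! / (n+1)!], using logGamma(k + 1) = log(k!)
            final double factorialForm = Gamma.logGamma(nAlt + 1) + Gamma.logGamma(nRef + 1) - Gamma.logGamma(n + 2);
            System.out.printf("gamma form = %.10f, factorial form = %.10f%n", gammaForm, factorialForm);
        }
    }

Both lines print the same value, confirming that passing Optional.empty() reproduces the old flat-prior result.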

// same as above but with a constant error probability for several alts
SomaticGenotypingEngine.java
@@ -40,8 +40,16 @@ public class SomaticGenotypingEngine
final boolean hasNormal;
protected VariantAnnotatorEngine annotationEngine;

// If MTAC.minAF is non-zero we softly cut off allele fractions below minAF with a Beta prior of the form Beta(1+epsilon, 1); that is
// the prior on allele fraction f is proportional to f^epsilon. If epsilon is small this prior vanishes as f -> 0
// and very rapidly becomes flat. We choose epsilon such that minAF^epsilon = 0.5.
private final double refPseudocount = 1;
private final double altPseudocount;

public SomaticGenotypingEngine(final M2ArgumentCollection MTAC, final Set<String> normalSamples, final VariantAnnotatorEngine annotationEngine) {
this.MTAC = MTAC;
altPseudocount = MTAC.minAF == 0.0 ? 1 : 1 - Math.log(2)/Math.log(MTAC.minAF);

this.normalSamples = normalSamples;
hasNormal = !normalSamples.isEmpty();
this.annotationEngine = annotationEngine;
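
To make the comment above concrete: choosing epsilon = -log(2)/log(minAF) gives minAF^epsilon = 0.5, so the Beta(1 + epsilon, 1) prior density is exactly halved at the threshold. A hypothetical spot check (the minAF value is invented for illustration):

    public class MinAfPriorCheck {
        public static void main(String[] args) {
            final double minAF = 0.01;                                        // illustrative threshold
            final double altPseudocount = 1 - Math.log(2) / Math.log(minAF);  // 1 + epsilon, as in the constructor above
            final double epsilon = altPseudocount - 1;                        // about 0.15 for minAF = 0.01
            System.out.println(Math.pow(minAF, epsilon));                     // prints 0.5 (up to rounding)
        }
    }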
@@ -199,24 +207,27 @@ public CalledHaplotypes callMutations(
return new CalledHaplotypes(outputCallsWithEventCountAnnotation, calledHaplotypes);
}

private double[] makePriorPseudocounts(final int numAlleles) {
return new IndexRange(0, numAlleles).mapToDouble(n -> n == 0 ? refPseudocount : altPseudocount);
}

// compute the likelihoods that each allele is contained at some allele fraction in the sample
protected <EVIDENCE extends Locatable> PerAlleleCollection<Double> somaticLogOdds(final LikelihoodMatrix<EVIDENCE, Allele> logMatrix) {
final int alleleListEnd = logMatrix.alleles().size()-1;
final int nonRefIndex = logMatrix.alleles().contains(Allele.NON_REF_ALLELE)
&& logMatrix.alleles().get(alleleListEnd).equals(Allele.NON_REF_ALLELE) ? alleleListEnd : -1;
if (logMatrix.alleles().contains(Allele.NON_REF_ALLELE) && !(logMatrix.alleles().get(alleleListEnd).equals(Allele.NON_REF_ALLELE))) {
throw new IllegalStateException("<NON_REF> must be last in the allele list.");
}

final double logEvidenceWithAllAlleles = logMatrix.evidenceCount() == 0 ? 0 :
SomaticLikelihoodsEngine.logEvidence(getAsRealMatrix(logMatrix), MTAC.minAF, nonRefIndex);
SomaticLikelihoodsEngine.logEvidence(getAsRealMatrix(logMatrix), makePriorPseudocounts(logMatrix.numberOfAlleles()));

final PerAlleleCollection<Double> lods = new PerAlleleCollection<>(PerAlleleCollection.Type.ALT_ONLY);
final int refIndex = getRefIndex(logMatrix);
IntStream.range(0, logMatrix.numberOfAlleles()).filter(a -> a != refIndex).forEach( a -> {
final Allele allele = logMatrix.getAllele(a);
final LikelihoodMatrix<EVIDENCE, Allele> logMatrixWithoutThisAllele = SubsettedLikelihoodMatrix.excludingAllele(logMatrix, allele);
final double logEvidenceWithoutThisAllele = logMatrixWithoutThisAllele.evidenceCount() == 0 ? 0 :
SomaticLikelihoodsEngine.logEvidence(getAsRealMatrix(logMatrixWithoutThisAllele), MTAC.minAF, logMatrixWithoutThisAllele.numberOfAlleles() > 1 ? nonRefIndex-1 : -1); //nonRefIndex-1 because we're evaluating without one allele; if th
SomaticLikelihoodsEngine.logEvidence(getAsRealMatrix(logMatrixWithoutThisAllele), makePriorPseudocounts(logMatrixWithoutThisAllele.numberOfAlleles()));
lods.setAlt(allele, logEvidenceWithAllAlleles - logEvidenceWithoutThisAllele);
});
return lods;
SomaticLikelihoodsEngine.java
@@ -2,7 +2,6 @@

import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.special.Beta;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.util.MathArrays;
import org.broadinstitute.hellbender.utils.*;
@@ -34,7 +33,7 @@ public static double[] alleleFractionsPosterior(final RealMatrix logLikelihoods,
// alleleCounts = \sum_r \bar{z}_r, where \bar{z}_r is an a-dimensional vector of the expectation of z_r with respect to q(f)
final double[] alleleCounts = getEffectiveCounts(logLikelihoods, dirichletPosterior);
final double[] newDirichletPosterior = MathArrays.ebeAdd(alleleCounts, priorPseudocounts);
converged = MathArrays.distance1(dirichletPosterior, newDirichletPosterior) < CONVERGENCE_THRESHOLD;
converged = MathArrays.distance1(dirichletPosterior, newDirichletPosterior)/MathUtils.sum(newDirichletPosterior) < CONVERGENCE_THRESHOLD;
dirichletPosterior = newDirichletPosterior;
}
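
The convergence test is now relative rather than absolute: dividing the L1 distance by the posterior total means the same fractional change counts as converged at any depth, whereas the old absolute threshold got harder to reach as pseudocounts grew with read count. A hypothetical illustration with invented numbers:

    import org.apache.commons.math3.util.MathArrays;

    public class ConvergenceCheck {
        public static void main(String[] args) {
            final double[] prev = {900.0, 100.0};   // invented posterior pseudocounts from a deep pileup
            final double[] next = {900.5, 99.5};
            final double absolute = MathArrays.distance1(prev, next);   // 1.0, large despite a tiny fractional shift
            final double relative = absolute / (next[0] + next[1]);     // 0.001
            System.out.printf("absolute = %.3f, relative = %.4f%n", absolute, relative);
        }
    }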

@@ -55,32 +54,16 @@ protected static double[] getEffectiveCounts(RealMatrix logLikelihoods, double[]
read -> NaturalLogUtils.posteriors(effectiveLogWeights, logLikelihoods.getColumn(read)));
}

/**
* @param logLikelihoods matrix of alleles x reads
* @param priorPseudocounts Dirichlet prior pseudocounts, one per allele
*/
public static double logEvidence(final RealMatrix logLikelihoods, final double[] priorPseudocounts) {
return logEvidence(logLikelihoods, priorPseudocounts, 0.0, -1);
}


/**
* @param logLikelihoods matrix of alleles x reads (NOTE: NON_REF allele is assumed to be last)
* @param priorPseudocounts
* @param alleleFractionThreshold lower bound of allele fractions to consider for non-ref likelihood
*/
public static double logEvidence(final RealMatrix logLikelihoods, final double[] priorPseudocounts, final double alleleFractionThreshold, final int nonRefIndex) {
public static double logEvidence(final RealMatrix logLikelihoods, final double[] priorPseudocounts) {
final int numberOfAlleles = logLikelihoods.getRowDimension();
Utils.validateArg(numberOfAlleles == priorPseudocounts.length, "Must have one pseudocount per allele.");
final double[] alleleFractionsPosterior = alleleFractionsPosterior(logLikelihoods, priorPseudocounts);
final double priorContribution = logDirichletNormalization(priorPseudocounts);
final double posteriorContribution = -logDirichletNormalization(alleleFractionsPosterior);
final double posteriorTotal = MathUtils.sum(alleleFractionsPosterior);
double thresholdedPosteriorContribution = posteriorContribution;
if (nonRefIndex > 0) {
thresholdedPosteriorContribution += Math.log(1-Beta.regularizedBeta(alleleFractionThreshold,
alleleFractionsPosterior[nonRefIndex], posteriorTotal - alleleFractionsPosterior[nonRefIndex]));
}

final double[] logAlleleFractions = new Dirichlet(alleleFractionsPosterior).effectiveLogMultinomialWeights();

@@ -91,7 +74,7 @@ public static double logEvidence(final RealMatrix logLikelihoods, final double[]
return likelihoodsContribution(logLikelihoodsForRead, responsibilities) - entropyContribution;
});

return priorContribution + thresholdedPosteriorContribution + likelihoodsAndEntropyContribution;
return priorContribution + posteriorContribution + likelihoodsAndEntropyContribution;
}
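
For orientation, the surviving two-argument logEvidence can be exercised directly on a toy matrix. A hypothetical sketch, assuming the GATK classes from this commit are on the classpath (the package path is assumed and the likelihood values are invented):

    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.broadinstitute.hellbender.tools.walkers.mutect.SomaticLikelihoodsEngine;

    public class LogEvidenceDemo {
        public static void main(String[] args) {
            // rows = alleles (ref, alt), columns = reads; entries are natural-log likelihoods
            final RealMatrix logLikelihoods = new Array2DRowRealMatrix(new double[][] {
                    {-0.1, -0.1, -4.0},   // ref allele
                    {-4.0, -4.0, -0.1}    // alt allele
            });
            final double[] flatPrior = {1.0, 1.0};  // one pseudocount per allele, flat Dirichlet
            System.out.println(SomaticLikelihoodsEngine.logEvidence(logLikelihoods, flatPrior));
        }
    }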

private static double likelihoodsContribution(final double[] logLikelihoodsForRead, final double[] responsibilities) {
@@ -106,13 +89,6 @@ private static double likelihoodsContribution(final double[] logLikelihoodsForRe
return result;
}


// same as above using the default flat prior
public static double logEvidence(final RealMatrix logLikelihoods, final double minAF, final int nonRefIndex) {
final double[] flatPrior = new IndexRange(0, logLikelihoods.getRowDimension()).mapToDouble(n -> 1);
return logEvidence(logLikelihoods, flatPrior, minAF, nonRefIndex);
}

private static double xLogx(final double x) {
return x < 1e-8 ? 0 : x * Math.log(x);
}
SomaticReferenceConfidenceModel.java
@@ -4,11 +4,13 @@
import htsjdk.variant.variantcontext.*;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReferenceConfidenceModel;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReferenceConfidenceResult;
import org.broadinstitute.hellbender.tools.walkers.readorientation.BetaDistributionShape;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
import org.broadinstitute.hellbender.utils.genotyper.IndexedAlleleList;
import org.broadinstitute.hellbender.utils.genotyper.SampleList;
import org.broadinstitute.hellbender.utils.param.ParamUtils;
import org.broadinstitute.hellbender.utils.pileup.PileupElement;
import org.broadinstitute.hellbender.utils.pileup.ReadPileup;
import org.broadinstitute.hellbender.utils.read.GATKRead;
@@ -19,24 +21,27 @@
public class SomaticReferenceConfidenceModel extends ReferenceConfidenceModel {

private final SampleList samples;
private final SomaticGenotypingEngine genotypingEngine;
private final Optional<BetaDistributionShape> afPrior;



/**
* Create a new ReferenceConfidenceModel
*
* @param samples the list of all samples we'll be considering with this model
* @param header the SAMFileHeader describing the read information (used for debugging)
* @param indelInformativeDepthIndelSize the max size of indels to consider when calculating indel informative depths
* @param minAF soft threshold for allele fractions -- above this value prior is nearly flat, below, prior is nearly zero
*/
SomaticReferenceConfidenceModel(final SampleList samples,
final SAMFileHeader header,
final int indelInformativeDepthIndelSize,
final SomaticGenotypingEngine genotypingEngine){
SomaticReferenceConfidenceModel(final SampleList samples, final SAMFileHeader header, final int indelInformativeDepthIndelSize,
final double minAF){
super(samples, header, indelInformativeDepthIndelSize, 0);
Utils.validateArg(minAF >= 0.0 && minAF < 1, "minAF must be < 1 and >= 0");

// To softly cut off allele fractions below minAF, we use a Beta prior of the form Beta(1+epsilon, 1); that is
// the prior on allele fraction f is proportional to f^epsilon. If epsilon is small this prior vanishes as f -> 0
// and very rapidly becomes flat. We choose epsilon such that minAF^epsilon = 0.5.
afPrior = minAF == 0.0 ? Optional.empty() : Optional.of(new BetaDistributionShape(1 - Math.log(2)/Math.log(minAF), 1));
this.samples = samples;
this.genotypingEngine = genotypingEngine;
}

/**
@@ -60,33 +65,27 @@ public ReferenceConfidenceResult calcGenotypeLikelihoodsOfRefVsAny(final int plo
final SomaticRefVsAnyResult result = new SomaticRefVsAnyResult();
final Map<String, List<GATKRead>> perSampleReadMap = new HashMap<>();
perSampleReadMap.put(samples.getSample(0), pileup.getReads());
final AlleleLikelihoods<GATKRead, Allele> readLikelihoods = new AlleleLikelihoods<>(samples, new IndexedAlleleList<>(Arrays.asList(Allele.create(refBase,true), Allele.NON_REF_ALLELE)), perSampleReadMap);
final AlleleLikelihoods<GATKRead, Allele> readLikelihoods2 = new AlleleLikelihoods<>(samples, new IndexedAlleleList<>(Arrays.asList(Allele.create(refBase,true), Allele.NON_REF_ALLELE)), perSampleReadMap);
final Iterator<PileupElement> pileupIter = pileup.iterator();
for (int i = 0; i < pileup.size(); i++) {
final PileupElement element = pileupIter.next();

final List<Byte> altQuals = new ArrayList<>(pileup.size() / 20);

for (final PileupElement element : pileup) {
if (!element.isDeletion() && element.getQual() <= minBaseQual) {
continue;
}

final boolean isAlt = readsWereRealigned ? isAltAfterAssembly(element, refBase) : isAltBeforeAssembly(element, refBase);
final double nonRefLikelihood;
final double refLikelihood;
if (isAlt) {
nonRefLikelihood = NaturalLogUtils.qualToLogProb(element.getQual());
refLikelihood = NaturalLogUtils.qualToLogErrorProb(element.getQual()) + NaturalLogUtils.LOG_ONE_THIRD;
altQuals.add(element.getQual());
result.nonRefDepth++;
} else {
nonRefLikelihood = NaturalLogUtils.qualToLogErrorProb(element.getQual()) + NaturalLogUtils.LOG_ONE_THIRD;
refLikelihood = NaturalLogUtils.qualToLogProb(element.getQual());
result.refDepth++;
}
readLikelihoods.sampleMatrix(0).set(0, i, nonRefLikelihood);
readLikelihoods2.sampleMatrix(0).set(0, i, refLikelihood);
readLikelihoods2.sampleMatrix(0).set(1, i, nonRefLikelihood);
}
result.lods = genotypingEngine.somaticLogOdds(readLikelihoods.sampleMatrix(0));
PerAlleleCollection<Double> lods2 = genotypingEngine.somaticLogOdds(readLikelihoods2.sampleMatrix(0));
result.lods = lods2;

final double logOdds = Mutect2Engine.logLikelihoodRatio(result.refDepth, altQuals, 1, afPrior);
result.lods = new PerAlleleCollection<>(PerAlleleCollection.Type.ALT_ONLY);
result.lods.set(Allele.NON_REF_ALLELE, logOdds);

return result;
}
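
The net effect of this rewrite: instead of assembling an AlleleLikelihoods matrix and running the full somaticLogOdds machinery per pileup, the model now collects alt base qualities and makes a single logLikelihoodRatio call. A hypothetical usage sketch, assuming GATK on the classpath, with invented depths and qualities:

    import java.util.Arrays;
    import java.util.List;
    import java.util.Optional;
    import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
    import org.broadinstitute.hellbender.tools.walkers.readorientation.BetaDistributionShape;

    public class RefVsAnyDemo {
        public static void main(String[] args) {
            final int refDepth = 28;                                              // invented ref-supporting read count
            final List<Byte> altQuals = Arrays.asList((byte) 30, (byte) 25, (byte) 32);
            // soft minAF = 0.01 cutoff, mirroring the afPrior field above
            final Optional<BetaDistributionShape> afPrior =
                    Optional.of(new BetaDistributionShape(1 - Math.log(2) / Math.log(0.01), 1));
            System.out.println(Mutect2Engine.logLikelihoodRatio(refDepth, altQuals, 1, afPrior));
        }
    }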

