Skip to content

Commit

Permalink
Filter M2 calls that are near other filtered calls on the same haplot…
Browse files Browse the repository at this point in the history
…ype (#5092)
  • Loading branch information
davidbenjamin authored Aug 8, 2018
1 parent dad77ac commit 74d41f9
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

import java.io.File;
import java.util.Optional;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -84,9 +82,7 @@ public final class FilterMutectCalls extends TwoPassVariantWalker {

private Mutect2FilteringEngine filteringEngine;

private List<FilterResult> firstPassFilterResults;

private Mutect2FilterSummary stats;
private FilteringFirstPass filteringFirstPass;

@Override
public void onTraversalStart() {
Expand All @@ -109,7 +105,7 @@ public void onTraversalStart() {
final Optional<String> normalSample = normalSampleHeaderLine == null ? Optional.empty() : Optional.of(normalSampleHeaderLine.getValue());

filteringEngine = new Mutect2FilteringEngine(MTFAC, tumorSample, normalSample);
firstPassFilterResults = new ArrayList<>();
filteringFirstPass = new FilteringFirstPass(tumorSample);
}

@Override
Expand All @@ -120,18 +116,18 @@ public Object onTraversalSuccess() {
/**
 * First-pass visitor: computes the filter result for a single variant and records it for later use.
 *
 * The filtering engine is called with {@code Optional.empty()} as its third argument, i.e. without any
 * information learned from a first pass. The reads, reference and feature contexts are not used here.
 *
 * NOTE(review): this view appears to interleave the pre-change statement ({@code firstPassFilterResults.add})
 * and the post-change statement ({@code filteringFirstPass.add}) of the commit; presumably only the latter
 * remains in the committed file -- confirm against the repository.
 *
 * @param vc           variant being filtered
 * @param readsContext unused in this method
 * @param refContext   unused in this method
 * @param fc           unused in this method
 */
@Override
public void firstPassApply(final VariantContext vc, final ReadsContext readsContext, final ReferenceContext refContext, final FeatureContext fc) {
final FilterResult filterResult = filteringEngine.calculateFilters(MTFAC, vc, Optional.empty());
firstPassFilterResults.add(filterResult);
filteringFirstPass.add(filterResult, vc);
}

/**
 * Hook that turns the accumulated first-pass results into the statistics used by the second pass and writes
 * the filtering summary table.
 *
 * The requested false positive rate comes from {@code MTFAC.maxFalsePositiveRate} and the output location from
 * {@code MTFAC.mutect2FilteringStatsTable}. Presumably invoked by the two-pass walker between the traversals --
 * confirm against {@code TwoPassVariantWalker}.
 *
 * NOTE(review): this view appears to interleave the pre-change statements ({@code stats = ...} and
 * {@code Mutect2FilterSummary.writeM2FilterSummary}) with the post-change statements (the two
 * {@code filteringFirstPass} calls); presumably only the latter pair remains in the committed file -- confirm.
 */
@Override
protected void afterFirstPass() {
stats = filteringEngine.calculateFilterStats(firstPassFilterResults, MTFAC.maxFalsePositiveRate);
Mutect2FilterSummary.writeM2FilterSummary(stats, MTFAC.mutect2FilteringStatsTable);
filteringFirstPass.learnModelForSecondPass(MTFAC.maxFalsePositiveRate);
filteringFirstPass.writeM2FilterSummary(MTFAC.mutect2FilteringStatsTable);
}

@Override
public void secondPassApply(final VariantContext vc, final ReadsContext readsContext, final ReferenceContext refContext, final FeatureContext fc) {
final FilterResult filterResult = filteringEngine.calculateFilters(MTFAC, vc, Optional.of(stats));
final FilterResult filterResult = filteringEngine.calculateFilters(MTFAC, vc, Optional.of(filteringFirstPass));
final VariantContextBuilder vcb = new VariantContextBuilder(vc);

vcb.filters(filterResult.getFilters());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package org.broadinstitute.hellbender.tools.walkers.mutect;


import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.VariantContext;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;
import org.broadinstitute.hellbender.utils.tsv.DataLine;
import org.broadinstitute.hellbender.utils.tsv.TableColumnCollection;
import org.broadinstitute.hellbender.utils.tsv.TableWriter;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;

import java.io.File;
import java.io.IOException;
import java.util.*;

/**
 * Stores the results of the first pass of {@link FilterMutectCalls}, a purely online step in which each variant is
 * not "aware" of other variants, and learns various global properties necessary for a more refined second step.
 *
 * <p>Usage as visible in this class: call {@link #add} once per variant, then {@link #learnModelForSecondPass}
 * to compute the per-filter statistics, after which {@link #isReadyForSecondPass()} returns {@code true} and
 * {@link #getFilterStats}, {@link #isOnFilteredHaplotype} and {@link #writeM2FilterSummary} can be used.</p>
 */
public class FilteringFirstPass {
    /** Every filter result passed to {@link #add}, in insertion order. */
    final List<FilterResult> filterResults;

    /** Maps a phasing ID (PID) to the phasing genotype (PGT) and start position of a filtered call with that PID. */
    final Map<String, ImmutablePair<String, Integer>> filteredPhasedCalls;

    /** Statistics learned by {@link #learnModelForSecondPass}, keyed by filter name. */
    final Map<String, FilterStats> filterStats;

    /** Name of the sample whose genotype is inspected for phasing attributes. */
    final String tumorSample;

    /** Set to {@code true} once {@link #learnModelForSecondPass} has completed. */
    boolean readyForSecondPass;

    /**
     * Creates an empty first-pass accumulator.
     *
     * @param tumorSample name of the sample whose genotype is looked up in each variant
     */
    public FilteringFirstPass(final String tumorSample) {
        filterResults = new ArrayList<>();
        filteredPhasedCalls = new HashMap<>();
        filterStats = new HashMap<>();
        readyForSecondPass = false;
        this.tumorSample = tumorSample;
    }

    /**
     * @return {@code true} once {@link #learnModelForSecondPass} has been run on this instance
     */
    public boolean isReadyForSecondPass() {
        return readyForSecondPass;
    }

    /**
     * Looks up the statistics learned for a filter.
     *
     * @param filterName name of the filter; must be a key previously stored by {@link #learnModelForSecondPass}
     * @return the statistics stored under {@code filterName}
     */
    public FilterStats getFilterStats(final String filterName) {
        // Utils.validateArg rejects unknown names (presumably with an IllegalArgumentException -- confirm).
        Utils.validateArg(filterStats.containsKey(filterName), "invalid filter name: " + filterName);
        return filterStats.get(filterName);
    }

    /**
     * Determines whether a variant shares a haplotype with a nearby call that was filtered in the first pass.
     *
     * @param vc          variant to test
     * @param maxDistance largest allowed absolute difference between the two start positions
     * @return {@code true} only if the tumor genotype carries both phasing attributes, a filtered call was recorded
     *         under the same PID, that call has the same PGT, and its start is within {@code maxDistance} of
     *         {@code vc}'s start
     */
    public boolean isOnFilteredHaplotype(final VariantContext vc, final int maxDistance) {
        final Genotype tumorGenotype = vc.getGenotype(tumorSample);

        if (!hasPhaseInfo(tumorGenotype)) {
            return false;
        }

        final String pgt = (String) tumorGenotype.getExtendedAttribute(GATKVCFConstants.HAPLOTYPE_CALLER_PHASING_GT_KEY, "");
        final String pid = (String) tumorGenotype.getExtendedAttribute(GATKVCFConstants.HAPLOTYPE_CALLER_PHASING_ID_KEY, "");
        final int position = vc.getStart();

        final Pair<String, Integer> filteredCall = filteredPhasedCalls.get(pid);
        if (filteredCall == null) {
            return false;
        }

        // Check that vc occurs on the filtered haplotype
        return filteredCall.getLeft().equals(pgt) && Math.abs(filteredCall.getRight() - position) <= maxDistance;
    }

    /**
     * Records the first-pass result for one variant.
     *
     * <p>The result is always appended to the list returned by {@link #getFilterResults()}. If it carries at least
     * one filter and the tumor genotype has phasing attributes, the variant's PGT and start position are also
     * stored under its PID for later use by {@link #isOnFilteredHaplotype}.</p>
     *
     * @param filterResult first-pass filtering outcome for {@code vc}
     * @param vc           the variant that produced {@code filterResult}
     */
    public void add(final FilterResult filterResult, final VariantContext vc) {
        filterResults.add(filterResult);
        final Genotype tumorGenotype = vc.getGenotype(tumorSample);

        if (!filterResult.getFilters().isEmpty() && hasPhaseInfo(tumorGenotype)) {
            final String pgt = (String) tumorGenotype.getExtendedAttribute(GATKVCFConstants.HAPLOTYPE_CALLER_PHASING_GT_KEY, "");
            final String pid = (String) tumorGenotype.getExtendedAttribute(GATKVCFConstants.HAPLOTYPE_CALLER_PHASING_ID_KEY, "");
            final int position = vc.getStart();
            // NOTE(review): put() replaces any entry already stored under this PID, so only the most recently
            // added filtered call per PID is retained. Consider keeping all of them if several filtered calls
            // can share a PID.
            filteredPhasedCalls.put(pid, new ImmutablePair<>(pgt, position));
        }
    }

    /**
     * Learns the read-orientation filter statistics from the variants that received no filter in the first pass
     * and marks this object as ready for the second pass.
     *
     * @param requestedFPR false positive rate passed on to {@link #calculateThresholdForReadOrientationFilter}
     */
    public void learnModelForSecondPass(final double requestedFPR) {
        // Only variants that were not filtered in the first pass contribute to the threshold
        final double[] readOrientationPosteriors = getFilterResults().stream()
                .filter(r -> r.getFilters().isEmpty())
                .mapToDouble(FilterResult::getReadOrientationPosterior)
                .toArray();

        final FilterStats readOrientationFilterStats = calculateThresholdForReadOrientationFilter(readOrientationPosteriors, requestedFPR);
        filterStats.put(GATKVCFConstants.READ_ORIENTATION_ARTIFACT_FILTER_NAME, readOrientationFilterStats);
        readyForSecondPass = true;
    }

    /**
     * Compute the filtering threshold that ensures that the false positive rate among the resulting pass variants
     * will not exceed the requested false positive rate.
     *
     * <p>The posteriors are sorted in ascending order and accumulated one at a time; the threshold is the last
     * posterior that can be included while the running mean stays at or below {@code requestedFPR}.</p>
     *
     * @param posteriors   posterior probabilities; this array is sorted in place
     * @param requestedFPR we set the filtering threshold such that the FPR doesn't exceed this value; must be
     *                     non-negative
     * @return statistics holding the chosen threshold, the sum of the included posteriors, how many posteriors
     *         were included, their mean (0.0 when none were included) and {@code requestedFPR}
     */
    public static FilterStats calculateThresholdForReadOrientationFilter(final double[] posteriors, final double requestedFPR) {
        ParamUtils.isPositiveOrZero(requestedFPR, "requested FPR must be non-negative");
        final double thresholdForFilteringNone = 1.0;
        final double thresholdForFilteringAll = 0.0;

        Arrays.sort(posteriors);

        final int numPosteriors = posteriors.length;
        double cumulativeExpectedFPs = 0.0;

        for (int i = 0; i < numPosteriors; i++) {
            final double posterior = posteriors[i];

            // The input is sorted ascending, so the running mean never decreases as i grows; the first index that
            // pushes it above the requested rate is therefore the cut-off.
            final double expectedFPR = (cumulativeExpectedFPs + posterior) / (i + 1);
            if (expectedFPR > requestedFPR) {
                if (i == 0) {
                    // Even the smallest posterior exceeds the requested rate: nothing can be included.
                    return new FilterStats(GATKVCFConstants.READ_ORIENTATION_ARTIFACT_FILTER_NAME, thresholdForFilteringAll,
                            0.0, 0, 0.0, requestedFPR);
                }
                // Indices 0..i-1 were accumulated, i.e. exactly i posteriors. Reporting i (rather than i-1) keeps
                // the count consistent with expectedNumFPs and with expectedFPR = expectedNumFPs / count.
                return new FilterStats(GATKVCFConstants.READ_ORIENTATION_ARTIFACT_FILTER_NAME, posteriors[i - 1],
                        cumulativeExpectedFPs, i, cumulativeExpectedFPs / i, requestedFPR);
            }

            cumulativeExpectedFPs += posterior;
        }

        // If the expected FP rate never exceeded the max tolerable value, then we can let everything pass.
        // With no posteriors at all, report a rate of 0.0 instead of the NaN that 0.0/0 would produce.
        final double overallExpectedFPR = numPosteriors == 0 ? 0.0 : cumulativeExpectedFPs / numPosteriors;
        return new FilterStats(GATKVCFConstants.READ_ORIENTATION_ARTIFACT_FILTER_NAME, thresholdForFilteringNone,
                cumulativeExpectedFPs, numPosteriors, overallExpectedFPR, requestedFPR);
    }

    /**
     * Tells whether a genotype carries both phasing attributes (PGT and PID).
     *
     * @param genotype genotype to inspect; {@code null} is treated as having no phase information
     *                 (presumably what a lookup for a sample missing from the variant yields -- confirm)
     * @return {@code true} if {@code genotype} is non-null and has both the phasing-genotype and phasing-ID
     *         extended attributes
     */
    public static boolean hasPhaseInfo(final Genotype genotype) {
        return genotype != null
                && genotype.hasExtendedAttribute(GATKVCFConstants.HAPLOTYPE_CALLER_PHASING_GT_KEY)
                && genotype.hasExtendedAttribute(GATKVCFConstants.HAPLOTYPE_CALLER_PHASING_ID_KEY);
    }

    /**
     * @return the live list of results recorded by {@link #add} (not a copy)
     */
    public List<FilterResult> getFilterResults() {
        return filterResults;
    }

    /** Immutable summary of the threshold chosen for one filter and the quantities it was derived from. */
    public static class FilterStats {
        private final String filterName;
        private final double threshold;
        private final double expectedNumFPs;
        private final int numPassingVariants;
        private final double expectedFPR;
        private final double requestedFPR;

        /**
         * @param filterName         name of the filter these statistics belong to
         * @param threshold          chosen filtering threshold
         * @param expectedNumFPs     sum of the posteriors of the passing variants
         * @param numPassingVariants number of passing variants
         * @param expectedFPR        expected false positive rate among the passing variants
         * @param requestedFPR       false positive rate that was requested
         */
        public FilterStats(final String filterName, final double threshold, final double expectedNumFPs,
                           final int numPassingVariants, final double expectedFPR, final double requestedFPR) {
            this.filterName = filterName;
            this.threshold = threshold;
            this.expectedNumFPs = expectedNumFPs;
            this.numPassingVariants = numPassingVariants;
            this.expectedFPR = expectedFPR;
            this.requestedFPR = requestedFPR;
        }

        public String getFilterName() { return filterName; }

        public double getExpectedNumFPs() { return expectedNumFPs; }

        public int getNumPassingVariants() { return numPassingVariants; }

        public double getThreshold() { return threshold; }

        public double getExpectedFPR() { return expectedFPR; }

        public double getRequestedFPR() { return requestedFPR; }

    }

    /** Columns of the filtering statistics table; {@link #toString()} yields the column header text. */
    private enum M2FilterStatsTableColumn {
        FILTER_NAME("filter_name"),
        THRESHOLD("threshold"),
        EXPECTED_FALSE_POSITIVES("expected_fps"),
        EXPECTED_FALSE_POSITIVE_RATE("expected_fpr"),
        REQUESTED_FALSE_POSITIVE_RATE("requested_fpr"),
        NUM_PASSING_VARIANTS("num_passing_variants");

        private final String columnName;

        M2FilterStatsTableColumn(final String columnName) {
            this.columnName = columnName;
        }

        @Override
        public String toString() { return columnName; }

        public static final TableColumnCollection COLUMNS = new TableColumnCollection((Object[]) values());
    }

    /** Table writer that emits one row per {@link FilterStats} using the columns above. */
    private static class Mutect2FilterStatsWriter extends TableWriter<FilterStats> {
        private Mutect2FilterStatsWriter(final File output) throws IOException {
            super(output, M2FilterStatsTableColumn.COLUMNS);
        }

        @Override
        protected void composeLine(final FilterStats stats, final DataLine dataLine) {
            dataLine.set(M2FilterStatsTableColumn.FILTER_NAME.toString(), stats.getFilterName())
                    .set(M2FilterStatsTableColumn.THRESHOLD.toString(), stats.getThreshold())
                    .set(M2FilterStatsTableColumn.EXPECTED_FALSE_POSITIVES.toString(), stats.getExpectedNumFPs())
                    .set(M2FilterStatsTableColumn.EXPECTED_FALSE_POSITIVE_RATE.toString(), stats.getExpectedFPR())
                    .set(M2FilterStatsTableColumn.REQUESTED_FALSE_POSITIVE_RATE.toString(), stats.getRequestedFPR())
                    .set(M2FilterStatsTableColumn.NUM_PASSING_VARIANTS.toString(), stats.getNumPassingVariants());
        }
    }

    /**
     * Writes every learned {@link FilterStats} as one row of a table.
     *
     * @param outputTable file to write the table to
     * @throws UserException if an {@link IOException} occurs while writing; the original exception is kept as the cause
     */
    public void writeM2FilterSummary(final File outputTable) {
        try (Mutect2FilterStatsWriter writer = new Mutect2FilterStatsWriter(outputTable)) {
            writer.writeAllRecords(filterStats.values());
        } catch (IOException e) {
            throw new UserException(String.format("Encountered an IO exception while writing to %s.", outputTable), e);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class M2FiltersArgumentCollection extends AssemblyBasedCallerArgumentColl
public static final String UNIQUE_ALT_READ_COUNT_LONG_NAME = "unique-alt-read-count";
public static final String TUMOR_SEGMENTATION_LONG_NAME = "tumor-segmentation";
public static final String ORIENTATION_BIAS_FDR_LONG_NAME = "orientation-bias-fdr"; // FDR = false discovery rate
public static final String MAX_DISTANCE_TO_FILTERED_CALL_ON_SAME_HAPLOTYPE_LONG_NAME = "distance-on-haplotype";

public static final String FILTERING_STATS_LONG_NAME = "stats";

Expand Down Expand Up @@ -124,4 +125,8 @@ public class M2FiltersArgumentCollection extends AssemblyBasedCallerArgumentColl
public File mutect2FilteringStatsTable = new File("Mutect2FilteringStats.tsv");


@Argument(fullName = MAX_DISTANCE_TO_FILTERED_CALL_ON_SAME_HAPLOTYPE_LONG_NAME, optional = true, doc = "On second filtering pass, variants with same PGT and PID tags as a filtered variant within this distance are filtered.")
public int maxDistanceToFilteredCallOnSameHaplotype = 100;


}
Loading

0 comments on commit 74d41f9

Please sign in to comment.