Fixed MarkDuplicatesSpark handling of unsorted bams #4732

Merged: 6 commits, May 9, 2018
Changes from 4 commits
@@ -117,7 +117,7 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
}
return null;
}).filter(v1 -> v1 != null);
return putPairsInSamePartition(header, reads);
return putPairsInSamePartition(header, reads, ctx);
}

/**
@@ -164,7 +164,7 @@ public JavaRDD<GATKRead> getADAMReads(final String inputPath, final TraversalPar
.values();
JavaRDD<GATKRead> readsRdd = recordsRdd.map(record -> new BDGAlignmentRecordToGATKReadAdapter(record, bHeader.getValue()));
JavaRDD<GATKRead> filteredRdd = readsRdd.filter(record -> samRecordOverlaps(record.convertToSAMRecord(header), traversalParameters));
return putPairsInSamePartition(header, filteredRdd);
return putPairsInSamePartition(header, filteredRdd, ctx);
}

/**
@@ -209,7 +209,7 @@ public boolean accept(Path path) {
* Ensure that reads in a pair fall in the same partition (input split) if the reads are queryname-sorted
* or querygroup-sorted, so they are processed together. No shuffle is needed.
*/
JavaRDD<GATKRead> putPairsInSamePartition(final SAMFileHeader header, final JavaRDD<GATKRead> reads) {
public static JavaRDD<GATKRead> putPairsInSamePartition(final SAMFileHeader header, final JavaRDD<GATKRead> reads, final JavaSparkContext ctx) {
if (!ReadUtils.isReadNameGroupedBam(header)) {
return reads;
}
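
The body of putPairsInSamePartition is collapsed in this view, so here is a hypothetical, self-contained sketch of the general technique: collect each partition's leading name group, broadcast that small map, then shift every boundary group back into the preceding partition. Plain strings of the form "name/1" stand in for GATKRead, the class and method names are invented for illustration, and this is not GATK's actual implementation; note that broadcasting via the driver is one plausible reason the real method now needs a JavaSparkContext.

```java
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class PairPartitionFixerSketch {

    private static String nameOf(final String read) {
        // Reads in this sketch look like "READ7/1"; the name is everything before '/'.
        return read.substring(0, read.indexOf('/'));
    }

    /** Reunite name groups that a range-partitioned sort split across partition boundaries. */
    public static JavaRDD<String> putGroupsInSamePartition(final JavaSparkContext ctx,
                                                           final JavaRDD<String> reads) {
        // 1) Collect the leading name group of every partition on the driver. This is small
        //    (one group per partition), which is what makes the broadcast cheap. Note that
        //    reads is traversed twice overall; in practice you would cache() it first.
        final Map<Integer, List<String>> heads = new HashMap<>();
        reads.mapPartitionsWithIndex((idx, it) -> {
            final List<String> head = new ArrayList<>();
            if (it.hasNext()) {
                head.add(it.next());
                while (it.hasNext()) {
                    final String r = it.next();
                    if (!nameOf(r).equals(nameOf(head.get(0)))) break;
                    head.add(r);
                }
            }
            return Collections.singletonList(new scala.Tuple2<>(idx, head)).iterator();
        }, false).collect().forEach(t -> heads.put(t._1, t._2));

        final Broadcast<Map<Integer, List<String>>> bHeads = ctx.broadcast(heads);

        // 2) Each partition after the first surrenders its leading group (the previous
        //    partition claims it), then claims the leading group of the next partition.
        //    Caveat: this simple version is wrong if one name group fills an entire
        //    partition; the real GATK code detects that case and fails loudly.
        return reads.mapPartitionsWithIndex((idx, it) -> {
            final List<String> out = new ArrayList<>();
            for (int i = 0; idx > 0 && i < bHeads.value().get(idx).size(); i++) {
                it.next(); // drop our leading group; it now lives in partition idx - 1
            }
            it.forEachRemaining(out::add);
            final List<String> next = bHeads.value().get(idx + 1);
            if (next != null) out.addAll(next);
            return out.iterator();
        }, false);
    }
}
```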
@@ -17,13 +17,16 @@
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.markduplicates.DuplicationMetrics;
import org.broadinstitute.hellbender.utils.read.markduplicates.MarkDuplicatesScoringStrategy;
import org.broadinstitute.hellbender.utils.read.markduplicates.OpticalDuplicateFinder;
import org.broadinstitute.hellbender.utils.spark.SparkUtils;
import picard.cmdline.programgroups.ReadDataManipulationProgramGroup;
import scala.Tuple2;

@@ -64,22 +67,50 @@ public List<ReadFilter> getDefaultReadFilters() {
return Collections.singletonList(ReadFilterLibrary.ALLOW_ALL_READS);
}

/**
* Main method for marking duplicates. Takes a JavaRDD of GATKRead and an associated SAMFileHeader with corresponding
* sorting information, and returns a new JavaRDD\<GATKRead\> in which the duplicate reads have been marked.
*
* NOTE: This method expects the incoming reads to be grouped by read name (queryname sorted/querygrouped) and for this
* to be explicitly set in the provided header. Furthermore, all the reads in a template must be grouped
* into the same partition or there may be problems with duplicate marking.
* If MarkDuplicates detects that reads are sorted in some other way, it will perform an extra sort operation first,
* so for performance reasons it is preferable to provide this method with already-sorted reads.
*
* @param reads input reads to be duplicate marked
* @param header header corresponding to the input reads
* @param scoringStrategy method by which duplicates are detected
* @param opticalDuplicateFinder finder used to identify and count optical duplicates among the duplicate reads
* @param numReducers number of partitions to separate the data into
* @param dontMarkUnmappedMates when true, unmapped mates of duplicate fragments will be marked as non-duplicates
* @return a JavaRDD of GATKReads where duplicate flags have been set
*/
public static JavaRDD<GATKRead> mark(final JavaRDD<GATKRead> reads, final SAMFileHeader header,
final MarkDuplicatesScoringStrategy scoringStrategy,
final OpticalDuplicateFinder opticalDuplicateFinder,
final int numReducers, final boolean dontMarkUnmappedMates) {
JavaRDD<GATKRead> sortedReadsForMarking;
SAMFileHeader headerForTool = header.clone();

// If the input isn't queryname sorted, sort it before duplicate marking
if (ReadUtils.isReadNameGroupedBam(header)) {
sortedReadsForMarking = reads;
} else {
headerForTool.setSortOrder(SAMFileHeader.SortOrder.queryname);
sortedReadsForMarking = ReadsSparkSource.putPairsInSamePartition(headerForTool, SparkUtils.querynameSortReads(reads, numReducers), new JavaSparkContext(reads.context()));
Member: Pull the sort onto its own line. It's not a great idea to hide really expensive operations inline with other calls.

Member: I might extract this whole sorting operation into a function, "queryNameSortReadsIfNecessary"

Collaborator (author): done

}
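
The review thread above asks for the sort to be pulled onto its own line and extracted into a helper, and the author answers "done". The extracted helper is not visible in this four-commit view; a plausible shape, using the name suggested in the review and only identifiers that appear elsewhere in this diff (and assuming the imports already present in MarkDuplicatesSpark), might be:

```java
private static JavaRDD<GATKRead> queryNameSortReadsIfNecessary(final JavaRDD<GATKRead> reads,
                                                               final SAMFileHeader headerForTool,
                                                               final int numReducers) {
    if (ReadUtils.isReadNameGroupedBam(headerForTool)) {
        return reads; // already grouped by name; nothing to do
    }
    headerForTool.setSortOrder(SAMFileHeader.SortOrder.queryname);
    // The expensive full shuffle, on its own line so the cost is visible.
    final JavaRDD<GATKRead> sorted = SparkUtils.querynameSortReads(reads, numReducers);
    return ReadsSparkSource.putPairsInSamePartition(headerForTool, sorted,
            new JavaSparkContext(sorted.context()));
}
```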

JavaPairRDD<MarkDuplicatesSparkUtils.IndexPair<String>, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(header, scoringStrategy, opticalDuplicateFinder, reads, numReducers);
JavaPairRDD<MarkDuplicatesSparkUtils.IndexPair<String>, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(headerForTool, scoringStrategy, opticalDuplicateFinder, sortedReadsForMarking, numReducers);

// Here we explicitly repartition the names of the non-duplicate reads to match the partitioning of the original bam
final JavaRDD<Tuple2<String,Integer>> repartitionedReadNames = namesOfNonDuplicates
.mapToPair(pair -> new Tuple2<>(pair._1.getIndex(), new Tuple2<>(pair._1.getValue(),pair._2)))
.partitionBy(new KnownIndexPartitioner(reads.getNumPartitions()))
.partitionBy(new KnownIndexPartitioner(sortedReadsForMarking.getNumPartitions()))
.values();
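
KnownIndexPartitioner itself is not part of this diff. Judging from its use here, where each key is the partition index recorded in an IndexPair, a minimal sketch consistent with that contract could be:

```java
import org.apache.spark.Partitioner;

// Illustrative only; not the actual GATK class.
public final class KnownIndexPartitionerSketch extends Partitioner {
    private final int numPartitions;

    public KnownIndexPartitionerSketch(final int numPartitions) {
        this.numPartitions = numPartitions;
    }

    @Override
    public int numPartitions() { return numPartitions; }

    @Override
    public int getPartition(final Object key) {
        // The key *is* the target partition index the record came from, so routing
        // it back is a direct lookup rather than a hash.
        return (Integer) key;
    }
}
```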

// Here we combine the original bam with the repartitioned non-duplicate read names to produce our marked reads
return reads.zipPartitions(repartitionedReadNames, (readsIter, readNamesIter) -> {
final Map<String,Integer> namesOfNonDuplicateReadsAndOpticalCounts = Utils.stream(readNamesIter).collect(Collectors.toMap(Tuple2::_1,Tuple2::_2));
return sortedReadsForMarking.zipPartitions(repartitionedReadNames, (readsIter, readNamesIter) -> {
final Map<String,Integer> namesOfNonDuplicateReadsAndOpticalCounts = Utils.stream(readNamesIter).collect(Collectors.toMap(Tuple2::_1, Tuple2::_2, (t1, t2) -> {throw new GATKException("Detected multiple duplicate-marking records corresponding to the same read name; this may be the result of read names spanning more than one partition");}));
return Utils.stream(readsIter).peek(read -> {
// Handle reads that have been marked as non-duplicates (which also get tagged with optical duplicate summary statistics)
if( namesOfNonDuplicateReadsAndOpticalCounts.containsKey(read.getName())) {
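
The only change to the toMap call above is its new third argument: a merge function that turns a silent duplicate-key overwrite into a descriptive failure. A small JDK-only demonstration of that pattern (all names here are illustrative):

```java
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public final class ToMapMergeDemo {
    public static void main(final String[] args) {
        try {
            // The third argument runs only when two stream elements map to the same key.
            final Map<String, Integer> counts = Stream.of("a", "b", "a")
                    .collect(Collectors.toMap(
                            name -> name,  // key: the read name, by analogy
                            name -> 1,     // value: a stand-in for the optical count
                            (left, right) -> {
                                throw new IllegalStateException(
                                        "duplicate key: a read name appears to span partitions");
                            }));
            System.out.println(counts);
        } catch (IllegalStateException e) {
            // "a" appears twice, so the merge function fires.
            System.out.println("merge function fired: " + e.getMessage());
        }
    }
}
```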
@@ -60,6 +60,11 @@ public int getIndex() {
this.value = value;
this.index = index;
}

@Override
public String toString() {
return "indexpair["+index+","+value.toString()+"]";
}
}

/**
@@ -167,9 +172,7 @@ private static JavaPairRDD<String, Iterable<IndexPair<GATKRead>>> getReadsGroupe
keyedReads = spanReadsByKey(indexedReads);
} else {
// sort by group and name (incurs a shuffle)
JavaPairRDD<String, IndexPair<GATKRead>> keyReadPairs = indexedReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForRead(
read.getValue()), read));
keyedReads = keyReadPairs.groupByKey(numReducers);
throw new GATKException("MarkDuplicatesSparkUtils.mark() requires input reads to be queryname sorted, yet the header indicated otherwise");
Member: Could you have it print the sort order it thinks it's in?

Collaborator (author): done

}
return keyedReads;
}
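
The reviewer above asks for the exception to print the sort order the header claims, and the author replies "done". That change is not shown in this four-commit view; assuming the method has the SAMFileHeader in scope, it might plausibly become:

```java
// Hypothetical revision; the merged wording may differ.
throw new GATKException(String.format(
        "MarkDuplicatesSparkUtils requires input reads to be queryname sorted or querygroup sorted, " +
        "but the header claims sort order '%s' and group order '%s'",
        header.getSortOrder(), header.getGroupOrder()));
```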
@@ -0,0 +1,61 @@
package org.broadinstitute.hellbender.utils.read;

import htsjdk.samtools.SAMRecordQueryNameComparator;
import htsjdk.samtools.SAMTag;

import java.io.Serializable;
import java.util.Comparator;

/**
* compare {@link GATKRead} by queryname
* duplicates the exact ordering of {@link SAMRecordQueryNameComparator}
*/
public class ReadQueryNameComparator implements Comparator<GATKRead>, Serializable {
private static final long serialVersionUID = 1L;

@Override
public int compare(final GATKRead read1, final GATKRead read2) {
int cmp = compareReadNames(read1, read2);
if (cmp != 0) {
return cmp;
}

final boolean r1Paired = read1.isPaired();
final boolean r2Paired = read2.isPaired();

if (r1Paired || r2Paired) {
if (!r1Paired) return 1;
else if (!r2Paired) return -1;
else if (read1.isFirstOfPair() && read2.isSecondOfPair()) return -1;
else if (read1.isSecondOfPair() && read2.isFirstOfPair()) return 1;
}

if (read1.isReverseStrand() != read2.isReverseStrand()) {
return (read1.isReverseStrand()? 1: -1);
}
if (read1.isSecondaryAlignment() != read2.isSecondaryAlignment()) {
return read2.isSecondaryAlignment()? -1: 1;
}
if (read1.isSupplementaryAlignment() != read2.isSupplementaryAlignment()) {
return read2.isSupplementaryAlignment() ? -1 : 1;
}
final Integer hitIndex1 = read1.getAttributeAsInteger(SAMTag.HI.name());
final Integer hitIndex2 = read2.getAttributeAsInteger(SAMTag.HI.name());
if (hitIndex1 != null) {
if (hitIndex2 == null) return 1;
else {
cmp = hitIndex1.compareTo(hitIndex2);
if (cmp != 0) return cmp;
}
} else if (hitIndex2 != null) return -1;
return 0;
}

/**
* compare read names lexicographically without any additional tie breakers
*/
public int compareReadNames(final GATKRead read1, final GATKRead read2) {
return read1.getName().compareTo(read2.getName());
}
}
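
Because ReadQueryNameComparator is a plain java.util.Comparator that is also Serializable, it can drive both Spark's sortByKey (as SparkUtils.querynameSortReads does below) and ordinary local sorting. A minimal usage sketch, assuming GATK on the classpath:

```java
import java.util.List;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadQueryNameComparator;

final class QuerynameSortDemo {
    // Sorts in place: names compare lexicographically, then ties fall through to the
    // pair flags, strand, secondary/supplementary status, and finally the HI tag.
    static void sortByQueryname(final List<GATKRead> reads) {
        reads.sort(new ReadQueryNameComparator());
    }
}
```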

@@ -144,6 +144,28 @@ public static JavaRDD<GATKRead> coordinateSortReads(final JavaRDD<GATKRead> read
return readVoidPairs.keys();
}

/**
* Sorts the given reads in queryname sort order.
* @param reads the reads to sort
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
*/
public static JavaRDD<GATKRead> querynameSortReads(final JavaRDD<GATKRead> reads, final int numReducers) {
// Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount
// of data going through the shuffle.
final JavaPairRDD<GATKRead, Void> rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null));

// do a total sort so that all the reads in partition i are less than those in partition i+1
final Comparator<GATKRead> comparator = new ReadQueryNameComparator();
final JavaPairRDD<GATKRead, Void> readVoidPairs;
if (numReducers > 0) {
readVoidPairs = rddReadPairs.sortByKey(comparator, true, numReducers);
} else {
readVoidPairs = rddReadPairs.sortByKey(comparator);
}
return readVoidPairs.keys();
Member: This method should call the edge fixing method. We don't want to give people the option to do it wrong.

Collaborator (author), @jamesemery, May 9, 2018: done

}
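
Per the thread above, the follow-up commit makes this method apply the edge fixing itself rather than leaving it to callers. That commit is outside this four-commit view; one plausible shape, which would need the header (and therefore a wider signature), is:

```java
// A sketch assuming the imports already used in SparkUtils plus ReadsSparkSource;
// the name and the added header parameter are hypothetical.
public static JavaRDD<GATKRead> querynameSortReadsAndFixPartitions(final JavaRDD<GATKRead> reads,
                                                                   final SAMFileHeader header,
                                                                   final int numReducers) {
    final JavaPairRDD<GATKRead, Void> rddReadPairs =
            reads.mapToPair(read -> new Tuple2<>(read, (Void) null));
    final Comparator<GATKRead> comparator = new ReadQueryNameComparator();
    final JavaPairRDD<GATKRead, Void> readVoidPairs = numReducers > 0
            ? rddReadPairs.sortByKey(comparator, true, numReducers)
            : rddReadPairs.sortByKey(comparator);
    final JavaRDD<GATKRead> sorted = readVoidPairs.keys();
    // Always fix up name groups split across partition boundaries, so callers can't forget.
    return ReadsSparkSource.putPairsInSamePartition(header, sorted,
            new JavaSparkContext(sorted.context()));
}
```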

/**
* Sorts the given reads according to the sort order in the header.
* @param reads the reads to sort
@@ -295,7 +295,7 @@ public void testPutPairsInSamePartition(int numPairs, int numPartitions, int num
header.setSortOrder(SAMFileHeader.SortOrder.queryname);
JavaRDD<GATKRead> reads = ctx.parallelize(createPairedReads(ctx, header, numPairs, numReadsInPair), numPartitions);
ReadsSparkSource readsSparkSource = new ReadsSparkSource(ctx);
JavaRDD<GATKRead> pairedReads = readsSparkSource.putPairsInSamePartition(header, reads);
JavaRDD<GATKRead> pairedReads = ReadsSparkSource.putPairsInSamePartition(header, reads, ctx);
List<List<GATKRead>> partitions = pairedReads.mapPartitions((FlatMapFunction<Iterator<GATKRead>, List<GATKRead>>) it ->
Iterators.singletonIterator(Lists.newArrayList(it))).collect();
assertEquals(partitions.size(), numPartitions);
@@ -336,6 +336,6 @@ public void testReadsPairsSpanningMultiplePartitionsCrash() throws IOException {

JavaRDD<GATKRead> problemReads = ctx.parallelize(reads,5 );
ReadsSparkSource readsSparkSource = new ReadsSparkSource(ctx);
readsSparkSource.putPairsInSamePartition(header, problemReads);
readsSparkSource.putPairsInSamePartition(header, problemReads, ctx);
}
}
@@ -12,6 +12,7 @@
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.utils.read.*;
import org.broadinstitute.hellbender.utils.read.markduplicates.MarkDuplicatesScoringStrategy;
import org.broadinstitute.hellbender.utils.read.markduplicates.OpticalDuplicateFinder;
import org.broadinstitute.hellbender.utils.read.markduplicates.ReadsKey;
import org.broadinstitute.hellbender.GATKBaseTest;
import org.broadinstitute.hellbender.utils.test.SamAssertionUtils;
@@ -66,4 +67,58 @@ private static Tuple2<String, Iterable<GATKRead>> pairIterable(String key, GATKR
return new Tuple2<>(key, ImmutableList.copyOf(reads));
}

@Test
// Test asserting that duplicate marking is sort-order agnostic: specifically, that when reads are scrambled across
// partitions in the input, all reads in a group are duplicate-marked together, just as they are for queryname-sorted bams
public void testSortOrderParitioningCorrectness() throws IOException {
Member: typo paritioning -> partitioning

Collaborator (author): done


JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
JavaRDD<GATKRead> unsortedReads = generateUnsortedReads(10000,3, ctx, 100, true);
Member: stupid nitpick: spaces here are wonky, and on the next line

JavaRDD<GATKRead> pariedEndsQueryGrouped = generateUnsortedReads(10000,3, ctx,1, false);

SAMFileHeader unsortedHeader = hg19Header.clone();
unsortedHeader.setSortOrder(SAMFileHeader.SortOrder.unsorted);
SAMFileHeader sortedHeader = hg19Header.clone();
sortedHeader.setSortOrder(SAMFileHeader.SortOrder.queryname);

// Using the header flagged as unsorted will result in the reads being sorted again
JavaRDD<GATKRead> unsortedReadsMarked = MarkDuplicatesSpark.mark(unsortedReads,unsortedHeader, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES,new OpticalDuplicateFinder(),100,true);
Member: this is called unsorted, but isn't it actually coordinate sorted?

Member: Why the different num reducers? Is that to find issues with edge fixing? If it is, I think you'd be better off with a specific (and possibly similar) test for that. Since we're always generating pairs, it seems to me that they might never get split across partitions if we're creating an even number of partitions.

Collaborator (author): So the reason for the different numbers of partitions is that when I first wrote this test there was no exposed way to do the edge fixing for a queryname-sorted bam. I didn't want to deal with the problems of having a mispartitioned bam, so I let the queryname-sorted reads reside on one partition so the spanning couldn't be wrong. Since this is a test of coordinate-sorted bam marking across partitions and not of the edge fixing, I'm not worried.

JavaRDD<GATKRead> sortedReadsMarked = MarkDuplicatesSpark.mark(pariedEndsQueryGrouped,sortedHeader, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES,new OpticalDuplicateFinder(),1,true);

Iterator<GATKRead> sortedReadsFinal = sortedReadsMarked.sortBy(GATKRead::commonToString, false, 1).collect().iterator();
Iterator<GATKRead> unsortedReadsFinal = unsortedReadsMarked.sortBy(GATKRead::commonToString, false, 1).collect().iterator();

// Comparing the output reads to ensure they are all duplicate marked correctly
while (sortedReadsFinal.hasNext()) {
GATKRead read1 = sortedReadsFinal.next();
GATKRead read2 = unsortedReadsFinal.next();
Assert.assertEquals(read1.getName(), read2.getName());
Assert.assertEquals(read1.isDuplicate(), read2.isDuplicate());
}
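
One small robustness gap worth noting: the loop above only runs while the sorted iterator has elements, so a length mismatch between the two outputs would pass silently. A hedged tweak (not part of this diff) would be a trailing check:

```java
// Illustrative addition: after the loop, both iterators should be exhausted, so a
// length mismatch fails loudly instead of being skipped.
Assert.assertFalse(sortedReadsFinal.hasNext());
Assert.assertFalse(unsortedReadsFinal.hasNext());
```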
}

private JavaRDD<GATKRead> generateUnsortedReads(int numReadGroups, int numDuplicatesPerGroup, JavaSparkContext ctx, int numPartitions, boolean coordinate) {
Member: these are sorted... rename to generateReadsWithDuplicates or something like that

Collaborator (author): done

Member: could you add a bit of javadoc to this method explaining what it does

Collaborator (author): done

int readNameCounter = 0;
SAMRecordSetBuilder samRecordSetBuilder = new SAMRecordSetBuilder(true, SAMFileHeader.SortOrder.coordinate,
true, SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH, SAMRecordSetBuilder.DEFAULT_DUPLICATE_SCORING_STRATEGY);

Random rand = new Random(10);
for (int i = 0; i < numReadGroups; i++ ) {
int start1 = rand.nextInt(SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH);
int start2 = rand.nextInt(SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH);
for (int j = 0; j < numDuplicatesPerGroup; j++) {
samRecordSetBuilder.addPair("READ" + readNameCounter++, 0, start1, start2);
}
}
final ReadCoordinateComparator coordinateComparitor = new ReadCoordinateComparator(hg19Header);
Member: coordinateComparitor is unused, and misspelled.

Collaborator (author): Yeah, it really does have a poor lot in life, doesn't it?

List<SAMRecord> records = Lists.newArrayList(samRecordSetBuilder.getRecords());
if (coordinate) {
records.sort(new SAMRecordCoordinateComparator());
} else {
records.sort(new SAMRecordQueryNameComparator());
}

return ctx.parallelize(records, numPartitions).map(SAMRecordToGATKReadAdapter::new);
}
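
The review asks for a rename and a javadoc on this method; both fixes landed after this four-commit snapshot. A sketch of what they might look like (hypothetical wording; the delegation to the existing method is purely illustrative):

```java
/**
 * Generates groups of duplicate read pairs for testing: numReadGroups distinct pairs of
 * start positions, each duplicated numDuplicatesPerGroup times, sorted by coordinate or
 * by queryname and spread across numPartitions partitions.
 */
private JavaRDD<GATKRead> generateReadsWithDuplicates(final int numReadGroups, final int numDuplicatesPerGroup,
                                                      final JavaSparkContext ctx, final int numPartitions,
                                                      final boolean coordinateSorted) {
    return generateUnsortedReads(numReadGroups, numDuplicatesPerGroup, ctx, numPartitions, coordinateSorted);
}
```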

}