From 63f22f161a860c77de3bb94a7726bb6d0a4e202e Mon Sep 17 00:00:00 2001 From: Louis Bergelson Date: Thu, 10 May 2018 16:05:07 -0400 Subject: [PATCH 1/3] fixes and refactoring methods in SparkUtils * fix partitioning bug by moving edge fixing from coordinateSortReads -> querynameSortReads * refactor methods to reduce code duplication * renaming and moving some methods * disallow duplicate sort order because it doesn't work with headerless reads --- .../spark/datasources/ReadsSparkSink.java | 5 +- .../spark/datasources/ReadsSparkSource.java | 84 ++------- .../markduplicates/MarkDuplicatesSpark.java | 5 +- .../hellbender/utils/spark/SparkUtils.java | 151 +++++++++++----- .../hellbender/utils/test/BaseTest.java | 22 +++ .../datasources/ReadsSparkSinkUnitTest.java | 4 +- .../datasources/ReadsSparkSourceUnitTest.java | 70 +------- .../utils/spark/SparkUtilsUnitTest.java | 168 +++++++++++++++++- .../hellbender/utils/test/ReadTestUtils.java | 9 +- 9 files changed, 315 insertions(+), 203 deletions(-) diff --git a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java index 8268ff99009..871e99c5bd3 100644 --- a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java +++ b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java @@ -22,12 +22,10 @@ import org.bdgenomics.formats.avro.AlignmentRecord; import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; -import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine; import org.broadinstitute.hellbender.utils.gcs.BucketUtils; import org.broadinstitute.hellbender.utils.io.IOUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.read.GATKReadToBDGAlignmentRecordConverter; -import org.broadinstitute.hellbender.utils.read.HeaderlessSAMRecordCoordinateComparator; import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat; import org.broadinstitute.hellbender.utils.spark.SparkUtils; import org.seqdoop.hadoop_bam.*; @@ -36,7 +34,6 @@ import java.io.File; import java.io.IOException; -import java.util.Comparator; /** * ReadsSparkSink writes GATKReads to a file. This code lifts from the HadoopGenomics/Hadoop-BAM @@ -287,7 +284,7 @@ private static void writeReadsSingle( final JavaSparkContext ctx, final String outputFile, final String referenceFile, final SAMFormat samOutputFormat, final JavaRDD reads, final SAMFileHeader header, final int numReducers, final String outputPartsDir) throws IOException { - final JavaRDD sortedReads = SparkUtils.sortReads(reads, header, numReducers); + final JavaRDD sortedReads = SparkUtils.sortSamRecordsToMatchHeader(reads, header, numReducers); final String outputPartsDirectory = (outputPartsDir == null)? 
getDefaultPartsDirectory(outputFile) : outputPartsDir; saveAsShardedHadoopFiles(ctx, outputPartsDirectory, referenceFile, samOutputFormat, sortedReads, header, false); logger.info("Finished sorting the bam file and dumping read shards to disk, proceeding to merge the shards into a single file using the master thread"); diff --git a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSource.java b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSource.java index 354ba8adcf5..46d1f71bb93 100644 --- a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSource.java +++ b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSource.java @@ -1,7 +1,5 @@ package org.broadinstitute.hellbender.engine.spark.datasources; -import com.google.common.collect.Iterators; -import com.google.common.collect.PeekingIterator; import htsjdk.samtools.*; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; @@ -16,12 +14,10 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.FlatMapFunction2; import org.apache.spark.broadcast.Broadcast; import org.bdgenomics.formats.avro.AlignmentRecord; import org.broadinstitute.hellbender.engine.ReadsDataSource; import org.broadinstitute.hellbender.engine.TraversalParameters; -import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; import org.broadinstitute.hellbender.utils.SimpleInterval; import org.broadinstitute.hellbender.utils.gcs.BucketUtils; @@ -37,9 +33,8 @@ import java.io.File; import java.io.IOException; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Iterator; import java.util.List; +import java.util.Objects; /** Loads the reads from disk either serially (using samReaderFactory) or in parallel using Hadoop-BAM. * The parallel code is a modified version of the example writing code from Hadoop-BAM. @@ -116,8 +111,17 @@ public JavaRDD getParallelReads(final String readFileName, final Strin return (GATKRead) SAMRecordToGATKReadAdapter.headerlessReadAdapter(sam); } return null; - }).filter(v1 -> v1 != null); - return putPairsInSamePartition(header, reads, ctx); + }).filter(Objects::nonNull); + + return fixPartitionsIfQueryGrouped(ctx, header, reads); + } + + private static JavaRDD fixPartitionsIfQueryGrouped(JavaSparkContext ctx, SAMFileHeader header, JavaRDD reads) { + if( ReadUtils.isReadNameGroupedBam(header)) { + return SparkUtils.putReadsWithTheSameNameInTheSamePartition(header, reads, ctx); + } else { + return reads; + } } /** @@ -164,7 +168,8 @@ public JavaRDD getADAMReads(final String inputPath, final TraversalPar .values(); JavaRDD readsRdd = recordsRdd.map(record -> new BDGAlignmentRecordToGATKReadAdapter(record, bHeader.getValue())); JavaRDD filteredRdd = readsRdd.filter(record -> samRecordOverlaps(record.convertToSAMRecord(header), traversalParameters)); - return putPairsInSamePartition(header, filteredRdd, ctx); + + return fixPartitionsIfQueryGrouped(ctx, header, filteredRdd); } /** @@ -205,67 +210,6 @@ public boolean accept(Path path) { } } - /** - * Ensure reads in a pair fall in the same partition (input split), if the reads are queryname-sorted, - * or querygroup sorted, so they are processed together. No shuffle is needed. 
- */ - public static JavaRDD putPairsInSamePartition(final SAMFileHeader header, final JavaRDD reads, final JavaSparkContext ctx) { - if (!ReadUtils.isReadNameGroupedBam(header)) { - return reads; - } - int numPartitions = reads.getNumPartitions(); - final String firstGroupInBam = reads.first().getName(); - // Find the first group in each partition - List> firstReadNamesInEachPartition = reads - .mapPartitions(it -> { PeekingIterator current = Iterators.peekingIterator(it); - List firstGroup = new ArrayList<>(2); - firstGroup.add(current.next()); - String name = firstGroup.get(0).getName(); - while (current.hasNext() && current.peek().getName().equals(name)) { - firstGroup.add(current.next()); - } - return Iterators.singletonIterator(firstGroup); - }) - .collect(); - - // Checking for pathological cases (read name groups that span more than 2 partitions) - String groupName = null; - for (List group : firstReadNamesInEachPartition) { - if (group!=null && !group.isEmpty()) { - // If a read spans multiple partitions we expect its name to show up multiple times and we don't expect this to work properly - if (groupName != null && group.get(0).getName().equals(groupName)) { - throw new GATKException(String.format("The read name '%s' appeared across multiple partitions this could indicate there was a problem " + - "with the sorting or that the rdd has too many partitions, check that the file is queryname sorted and consider decreasing the number of partitions", groupName)); - } - groupName = group.get(0).getName(); - } - } - - // Shift left, so that each partition will be joined with the first read group from the _next_ partition - List> firstReadInNextPartition = new ArrayList<>(firstReadNamesInEachPartition.subList(1, numPartitions)); - firstReadInNextPartition.add(null); // the last partition does not have any reads to add to it - - // Join the reads with the first read from the _next_ partition, then filter out the first and/or last read if not in a pair - return reads.zipPartitions(ctx.parallelize(firstReadInNextPartition, numPartitions), - (FlatMapFunction2, Iterator>, GATKRead>) (it1, it2) -> { - PeekingIterator current = Iterators.peekingIterator(it1); - String firstName = current.peek().getName(); - // Make sure we don't remove reads from the first partition - if (!firstGroupInBam.equals(firstName)) { - // skip the first read name group in the _current_ partition if it is the second in a pair since it will be handled in the previous partition - while (current.hasNext() && current.peek() != null && current.peek().getName().equals(firstName)) { - current.next(); - } - } - // append the first reads in the _next_ partition to the _current_ partition - PeekingIterator> next = Iterators.peekingIterator(it2); - if (next.hasNext() && next.peek() != null) { - return Iterators.concat(current, next.peek().iterator()); - } - return current; - }); - } - /** * Propagate any values that need to be passed to Hadoop-BAM through configuration properties: * diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java index d9697b89ae5..e85a86b3613 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java @@ -133,14 +133,13 @@ public static JavaRDD mark(final JavaRDD 
reads, final SAMFil /** * Sort reads into queryname order if they are not already sorted */ - protected static JavaRDD querynameSortReadsIfNecessary(JavaRDD reads, int numReducers, SAMFileHeader headerForTool) { + private static JavaRDD querynameSortReadsIfNecessary(JavaRDD reads, int numReducers, SAMFileHeader headerForTool) { JavaRDD sortedReadsForMarking; if (ReadUtils.isReadNameGroupedBam(headerForTool)) { sortedReadsForMarking = reads; } else { headerForTool.setSortOrder(SAMFileHeader.SortOrder.queryname); - JavaRDD sortedReads = SparkUtils.querynameSortReads(reads, numReducers); - sortedReadsForMarking = ReadsSparkSource.putPairsInSamePartition(headerForTool, sortedReads, JavaSparkContext.fromSparkContext(reads.context())); + sortedReadsForMarking = SparkUtils.querynameSortReads(reads, headerForTool, numReducers); } return sortedReadsForMarking; } diff --git a/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java b/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java index d281033c79f..530d217d5c9 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java @@ -1,5 +1,7 @@ package org.broadinstitute.hellbender.utils.spark; +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; import htsjdk.samtools.SAMFileHeader; import htsjdk.samtools.SAMRecord; import htsjdk.samtools.SAMSequenceRecord; @@ -15,9 +17,10 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction2; import org.apache.spark.broadcast.Broadcast; import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink; -import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource; +import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; import org.broadinstitute.hellbender.utils.read.*; import org.broadinstitute.hellbender.utils.Utils; @@ -25,7 +28,10 @@ import java.io.*; +import java.util.ArrayList; import java.util.Comparator; +import java.util.Iterator; +import java.util.List; /** * Miscellaneous Spark-related utilities @@ -130,79 +136,132 @@ public static boolean pathExists(final JavaSparkContext ctx, final Path targetPa public static JavaRDD coordinateSortReads(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.coordinate), "Header must specify coordinate sort order, but was" + header.getSortOrder()); - // Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount - // of data going through the shuffle. 
- final JavaPairRDD rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null)); - - // do a total sort so that all the reads in partition i are less than those in partition i+1 - final Comparator comparator = new ReadCoordinateComparator(header); - final JavaPairRDD readVoidPairs; - final JavaRDD output; - if (numReducers > 0) { - readVoidPairs = rddReadPairs.sortByKey(comparator, true, numReducers); - output = ReadsSparkSource.putPairsInSamePartition(header, readVoidPairs.keys(), new JavaSparkContext(readVoidPairs.context())); - } else { - readVoidPairs = rddReadPairs.sortByKey(comparator); - output = readVoidPairs.keys(); - } - return output; + return sort(reads, new ReadCoordinateComparator(header), numReducers); } /** * Sorts the given reads in queryname sort order. + * This guarantees that all reads that have the same read name are placed into the same partition * @param reads the reads to sort + * @param header header of the bam, the header is required to have been set to queryname order * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers * @return a sorted RDD of reads */ - public static JavaRDD querynameSortReads(final JavaRDD reads, final int numReducers) { - // Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount - // of data going through the shuffle. - final JavaPairRDD rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null)); + public static JavaRDD querynameSortReads(final JavaRDD reads, SAMFileHeader header, final int numReducers) { + Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.queryname), "Header must specify queryname sort order, but was " + header.getSortOrder()); - // do a total sort so that all the reads in partition i are less than those in partition i+1 - final Comparator comparator = new ReadQueryNameComparator(); - final JavaPairRDD readVoidPairs; - if (numReducers > 0) { - readVoidPairs = rddReadPairs.sortByKey(comparator, true, numReducers); - } else { - readVoidPairs = rddReadPairs.sortByKey(comparator); - } - return readVoidPairs.keys(); + final JavaRDD sortedReads = sort(reads, new ReadQueryNameComparator(), numReducers); + return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context())); } /** * Sorts the given reads according to the sort order in the header. * @param reads the reads to sort - * @param header the header specifying the sort order - * @param numReducers the number of reducers to use; a vlue of 0 means use the default number of reducers + * @param header the header specifying the sort order, + * if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown} + * then no sort will be performed + * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers * @return a sorted RDD of reads */ - public static JavaRDD sortReads(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { - // Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount - // of data going through the shuffle. 
- final JavaPairRDD rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null)); - - // do a total sort so that all the reads in partition i are less than those in partition i+1 + public static JavaRDD sortSamRecordsToMatchHeader(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { final Comparator comparator = getSAMRecordComparator(header); - final JavaPairRDD readVoidPairs; - if (comparator == null){ - readVoidPairs = rddReadPairs; //no sort - } else if (numReducers > 0) { - readVoidPairs = rddReadPairs.sortByKey(comparator, true, numReducers); + if ( comparator == null ) { + return reads; } else { - readVoidPairs = rddReadPairs.sortByKey(comparator); + return sort(reads, comparator, numReducers); } - return readVoidPairs.keys(); } //Returns the comparator to use or null if no sorting is required. private static Comparator getSAMRecordComparator(final SAMFileHeader header) { switch (header.getSortOrder()){ case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header); - case duplicate: + //duplicate isn't supported because it doesn't work right on headerless SAMRecords + case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark."); case queryname: case unsorted: return header.getSortOrder().getComparatorInstance(); default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default. } } + + /** + * do a total sort of an RDD so that all the elements in partition i are less than those in partition i+1 according to the given comparator + */ + private static JavaRDD sort(JavaRDD reads, Comparator comparator, int numReducers) { + Utils.nonNull(comparator); + Utils.nonNull(reads); + + // Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount + // of data going through the shuffle. + final JavaPairRDD rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null)); + + final JavaPairRDD readVoidPairs; + if (numReducers > 0) { + readVoidPairs = rddReadPairs.sortByKey(comparator, true, numReducers); + } else { + readVoidPairs = rddReadPairs.sortByKey(comparator); + } + return readVoidPairs.keys(); + } + + /** + * Ensure all reads with the same name appear in the same partition. + * Requires that the No shuffle is needed. + + */ + public static JavaRDD putReadsWithTheSameNameInTheSamePartition(final SAMFileHeader header, final JavaRDD reads, final JavaSparkContext ctx) { + Utils.validateArg(ReadUtils.isReadNameGroupedBam(header), () -> "Reads must be queryname grouped or sorted. 
" + + "Actual sort:" + header.getSortOrder() + " Actual grouping:" +header.getGroupOrder()); + int numPartitions = reads.getNumPartitions(); + final String firstGroupInBam = reads.first().getName(); + // Find the first group in each partition + List> firstReadNamesInEachPartition = reads + .mapPartitions(it -> { PeekingIterator current = Iterators.peekingIterator(it); + List firstGroup = new ArrayList<>(2); + firstGroup.add(current.next()); + String name = firstGroup.get(0).getName(); + while (current.hasNext() && current.peek().getName().equals(name)) { + firstGroup.add(current.next()); + } + return Iterators.singletonIterator(firstGroup); + }) + .collect(); + + // Checking for pathological cases (read name groups that span more than 2 partitions) + String groupName = null; + for (List group : firstReadNamesInEachPartition) { + if (group!=null && !group.isEmpty()) { + // If a read spans multiple partitions we expect its name to show up multiple times and we don't expect this to work properly + if (groupName != null && group.get(0).getName().equals(groupName)) { + throw new GATKException(String.format("The read name '%s' appeared across multiple partitions this could indicate there was a problem " + + "with the sorting or that the rdd has too many partitions, check that the file is queryname sorted and consider decreasing the number of partitions", groupName)); + } + groupName = group.get(0).getName(); + } + } + + // Shift left, so that each partition will be joined with the first read group from the _next_ partition + List> firstReadInNextPartition = new ArrayList<>(firstReadNamesInEachPartition.subList(1, numPartitions)); + firstReadInNextPartition.add(null); // the last partition does not have any reads to add to it + + // Join the reads with the first read from the _next_ partition, then filter out the first and/or last read if not in a pair + return reads.zipPartitions(ctx.parallelize(firstReadInNextPartition, numPartitions), + (FlatMapFunction2, Iterator>, GATKRead>) (it1, it2) -> { + PeekingIterator current = Iterators.peekingIterator(it1); + String firstName = current.peek().getName(); + // Make sure we don't remove reads from the first partition + if (!firstGroupInBam.equals(firstName)) { + // skip the first read name group in the _current_ partition if it is the second in a pair since it will be handled in the previous partition + while (current.hasNext() && current.peek() != null && current.peek().getName().equals(firstName)) { + current.next(); + } + } + // append the first reads in the _next_ partition to the _current_ partition + PeekingIterator> next = Iterators.peekingIterator(it2); + if (next.hasNext() && next.peek() != null) { + return Iterators.concat(current, next.peek().iterator()); + } + return current; + }); + } } diff --git a/src/main/java/org/broadinstitute/hellbender/utils/test/BaseTest.java b/src/main/java/org/broadinstitute/hellbender/utils/test/BaseTest.java index a0937ddcf4b..596e24991a6 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/test/BaseTest.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/test/BaseTest.java @@ -399,6 +399,28 @@ public static void assertCondition(Iterable actual, Iterable expected, } } + /** + * assert that the iterable is sorted according to the comparator + */ + public static void assertSorted(Iterable iterable, Comparator comparator){ + final Iterator iter = iterable.iterator(); + assertSorted(iter, comparator); + } + + /** + * assert that the iterator is sorted according to the comparator + */ + public 
static void assertSorted(Iterator iterator, Comparator comparator){ + T previous = null; + while(iterator.hasNext()){ + T current = iterator.next(); + if( previous != null) { + Assert.assertTrue(comparator.compare(previous, current) <= 0, "Expected " + previous + " to be <= " + current); + } + previous = current; + } + } + /** * Get a FileSystem that uses the explicit credentials instead of the default * credentials. diff --git a/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSinkUnitTest.java b/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSinkUnitTest.java index b0b6b446361..567e240dd13 100644 --- a/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSinkUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSinkUnitTest.java @@ -18,6 +18,7 @@ import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat; import org.broadinstitute.hellbender.GATKBaseTest; import org.broadinstitute.hellbender.utils.test.MiniClusterUtils; +import org.broadinstitute.hellbender.utils.test.ReadTestUtils; import org.seqdoop.hadoop_bam.SplittingBAMIndexer; import org.testng.Assert; import org.testng.annotations.AfterClass; @@ -25,13 +26,12 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import java.io.BufferedWriter; import java.io.File; -import java.io.FileWriter; import java.io.IOException; import java.nio.file.Files; import java.util.ArrayList; import java.util.Comparator; +import java.util.Iterator; import java.util.List; public class ReadsSparkSinkUnitTest extends GATKBaseTest { diff --git a/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSourceUnitTest.java b/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSourceUnitTest.java index b7061683d3d..7664e399a0e 100644 --- a/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSourceUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSourceUnitTest.java @@ -20,6 +20,7 @@ import org.broadinstitute.hellbender.utils.read.ReadConstants; import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter; import org.broadinstitute.hellbender.GATKBaseTest; +import org.broadinstitute.hellbender.utils.spark.SparkUtils; import org.broadinstitute.hellbender.utils.test.MiniClusterUtils; import org.testng.Assert; import org.testng.annotations.DataProvider; @@ -269,73 +270,4 @@ public JavaRDD getSerialReads(final JavaSparkContext ctx, final String } return ctx.parallelize(records); } - - @DataProvider(name="readPairsAndPartitions") - public Object[][] readPairsAndPartitions() { - return new Object[][] { - // number of pairs, number of partitions, number of reads per pair, expected reads per partition - { 1, 1, 2, new int[] {2} }, - { 2, 2, 2, new int[] {4, 0} }, - { 3, 2, 2, new int[] {4, 2} }, - { 3, 3, 2, new int[] {4, 2, 0} }, - { 6, 2, 2, new int[] {8, 4} }, - { 6, 3, 2, new int[] {6, 4, 2} }, - { 6, 4, 2, new int[] {4, 4, 2, 2} }, - { 2, 2, 3, new int[] {6, 0} }, - { 3, 2, 10, new int[] {20, 10} }, - { 6, 4, 3, new int[] {6, 6, 3, 3} }, - { 20, 7, 5, new int[] {15, 15, 15, 15, 15, 15, 10} }, - }; - } - - @Test(dataProvider = "readPairsAndPartitions") - public void testPutPairsInSamePartition(int numPairs, int numPartitions, int numReadsInPair, int[] expectedReadsPerPartition) throws IOException { - JavaSparkContext ctx = 
SparkContextFactory.getTestSparkContext(); - SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); - header.setSortOrder(SAMFileHeader.SortOrder.queryname); - JavaRDD reads = ctx.parallelize(createPairedReads(ctx, header, numPairs, numReadsInPair), numPartitions); - ReadsSparkSource readsSparkSource = new ReadsSparkSource(ctx); - JavaRDD pairedReads = ReadsSparkSource.putPairsInSamePartition(header, reads, ctx); - List> partitions = pairedReads.mapPartitions((FlatMapFunction, List>) it -> - Iterators.singletonIterator(Lists.newArrayList(it))).collect(); - assertEquals(partitions.size(), numPartitions); - for (int i = 0; i < numPartitions; i++) { - assertEquals(partitions.get(i).size(), expectedReadsPerPartition[i]); - } - assertEquals(Arrays.stream(expectedReadsPerPartition).sum(), numPairs * numReadsInPair); - } - - private List createPairedReads(JavaSparkContext ctx, SAMFileHeader header, int numPairs, int numReadsInPair) { - final int readSize = 151; - final int fragmentLen = 400; - final String templateName = "readpair"; - int leftStart = 10000; - List reads = new ArrayList<>(); - for (int i = 0; i < numPairs;i++) { - leftStart += readSize * 2; - int rightStart = leftStart + fragmentLen - readSize; - reads.addAll(ArtificialReadUtils.createPair(header, templateName + i, readSize, leftStart, rightStart, true, false)); - // Copying a secondary alignment for the second read to fill out the read group - GATKRead readToCopy = reads.get(reads.size()-1).copy(); - readToCopy.setIsSecondaryAlignment(true); - for (int j = 2; j < numReadsInPair; j++) { - reads.add(readToCopy.copy()); - } - } - return reads; - } - - @Test(expectedExceptions = GATKException.class) - public void testReadsPairsSpanningMultiplePartitionsCrash() throws IOException { - JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); - SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); - header.setSortOrder(SAMFileHeader.SortOrder.queryname); - List reads = createPairedReads(ctx, header, 40, 2); - // Creating one group in the middle that should cause problems - reads.addAll(40, createPairedReads(ctx, header, 1, 30)); - - JavaRDD problemReads = ctx.parallelize(reads,5 ); - ReadsSparkSource readsSparkSource = new ReadsSparkSource(ctx); - ReadsSparkSource.putPairsInSamePartition(header, problemReads, ctx); - } } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java index 7284b413c3b..6936e91e160 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java @@ -1,26 +1,41 @@ package org.broadinstitute.hellbender.utils.spark; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; import htsjdk.samtools.*; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import org.apache.spark.Partition; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; import org.broadinstitute.hellbender.engine.ReadsDataSource; import org.broadinstitute.hellbender.engine.spark.SparkContextFactory; +import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; +import 
org.broadinstitute.hellbender.utils.read.ArtificialReadUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.GATKBaseTest; +import org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator; +import org.broadinstitute.hellbender.utils.read.ReadQueryNameComparator; import org.broadinstitute.hellbender.utils.test.MiniClusterUtils; import org.testng.Assert; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.io.File; import java.io.IOException; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.testng.Assert.assertEquals; public class SparkUtilsUnitTest extends GATKBaseTest { @Test - public void testConvertHeaderlessHadoopBamShardToBam() throws Exception { + public void testConvertHeaderlessHadoopBamShardToBam() { final File bamShard = new File(publicTestDir + "org/broadinstitute/hellbender/utils/spark/reads_data_source_test1.bam.headerless.part-r-00000"); final File output = createTempFile("testConvertHadoopBamShardToBam", ".bam"); final File headerSource = new File(publicTestDir + "org/broadinstitute/hellbender/engine/reads_data_source_test1.bam"); @@ -69,6 +84,157 @@ public void testPathExists() throws Exception { fs.deleteOnExit(tempPath); Assert.assertTrue(SparkUtils.pathExists(ctx, tempPath)); }); + } + + @DataProvider(name="readPairsAndPartitions") + public Object[][] readPairsAndPartitions() { + return new Object[][] { + // number of pairs, number of partitions, number of reads per pair, expected reads per partition + { 1, 1, 2, new int[] {2} }, + { 2, 2, 2, new int[] {4, 0} }, + { 3, 2, 2, new int[] {4, 2} }, + { 3, 3, 2, new int[] {4, 2, 0} }, + { 6, 2, 2, new int[] {8, 4} }, + { 6, 3, 2, new int[] {6, 4, 2} }, + { 6, 4, 2, new int[] {4, 4, 2, 2} }, + { 2, 2, 3, new int[] {6, 0} }, + { 3, 2, 10, new int[] {20, 10} }, + { 6, 4, 3, new int[] {6, 6, 3, 3} }, + { 20, 7, 5, new int[] {15, 15, 15, 15, 15, 15, 10} }, + }; + } + + @Test(dataProvider = "readPairsAndPartitions") + public void testPutReadsWithSameNameInSamePartition(int numPairs, int numPartitions, int numReadsInPair, int[] expectedReadsPerPartition) { + JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); + header.setSortOrder(SAMFileHeader.SortOrder.queryname); + JavaRDD reads = ctx.parallelize(createPairedReads(header, numPairs, numReadsInPair), numPartitions); + JavaRDD pairedReads = SparkUtils.putReadsWithTheSameNameInTheSamePartition(header, reads, ctx); + List> partitions = pairedReads.mapPartitions((FlatMapFunction, List>) it -> + Iterators.singletonIterator(Lists.newArrayList(it))).collect(); + assertEquals(partitions.size(), numPartitions); + for (int i = 0; i < numPartitions; i++) { + assertEquals(partitions.get(i).size(), expectedReadsPerPartition[i]); + } + assertEquals(Arrays.stream(expectedReadsPerPartition).sum(), numPairs * numReadsInPair); + } + + private static List createPairedReads(SAMFileHeader header, int numPairs, int numReadsInPair) { + final int readSize = 151; + final int fragmentLen = 400; + final String templateName = "readpair"; + int leftStart = 10000; + List reads = new ArrayList<>(); + for (int i = 0; i < numPairs;i++) { + leftStart += readSize * 2; + int rightStart = leftStart + fragmentLen - readSize; + reads.addAll(ArtificialReadUtils.createPair(header, templateName + i, readSize, leftStart, rightStart, true, false)); + // Copying a secondary 
alignment for the second read to fill out the read group + GATKRead readToCopy = reads.get(reads.size()-1).copy(); + readToCopy.setIsSecondaryAlignment(true); + for (int j = 2; j < numReadsInPair; j++) { + reads.add(readToCopy.copy()); + } + } + return reads; + } + + @Test(expectedExceptions = GATKException.class) + public void testReadsPairsSpanningMultiplePartitionsCrash() { + JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); + header.setSortOrder(SAMFileHeader.SortOrder.queryname); + List reads = createPairedReads(header, 40, 2); + // Creating one group in the middle that should cause problems + reads.addAll(40, createPairedReads(header, 1, 30)); + + JavaRDD problemReads = ctx.parallelize(reads,5 ); + SparkUtils.putReadsWithTheSameNameInTheSamePartition(header, problemReads, ctx); + } + + @Test + public void testReadsMustBeQueryGroupedToFixPartitions(){ + JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); + header.setSortOrder(SAMFileHeader.SortOrder.coordinate); + List reads = createPairedReads(header, 40, 2); + final JavaRDD readsRDD = ctx.parallelize(reads, 5); + Assert.assertThrows(IllegalArgumentException.class, () -> SparkUtils.putReadsWithTheSameNameInTheSamePartition(header, readsRDD, ctx)); + } + + @Test + public void testSortCoordinateSort() { + JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); + List reads = new ArrayList<>(); + for(int i = 0; i < 2000; i++){ + //create reads with alternating contigs and decreasing start position + reads.add(ArtificialReadUtils.createArtificialRead(header, "READ"+i, i % header.getSequenceDictionary().size() , 3000 - i, 100)); + } + final JavaRDD readsRDD = ctx.parallelize(reads); + final List coordinateSorted = SparkUtils.coordinateSortReads(readsRDD, header, 0).collect(); + assertSorted(coordinateSorted, new ReadCoordinateComparator(header)); + assertSorted(coordinateSorted.stream().map(read -> read.convertToSAMRecord(header)).collect(Collectors.toList()), new SAMRecordCoordinateComparator()); + } + + @Test + public void testSortQuerynameSort() { + JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); + List reads = new ArrayList<>(); + final int numReads = 2000; + for(int i = 0; i < numReads; i++) { + + //create reads with non-lexicographically ordered names + //names are created in lexicographically decreasing order, with 2 repetitions to create "pairs" + reads.add(ArtificialReadUtils.createArtificialRead(header, "READ" + (numReads - i) % (numReads / 2), + i % header.getSequenceDictionary().size(), + 3000 - i, + 100)); + } + header.setSortOrder(SAMFileHeader.SortOrder.queryname); + final JavaRDD readsRDD = ctx.parallelize(reads); + final List querynameSorted = SparkUtils.querynameSortReads(readsRDD, header, 31).collect(); + assertSorted(querynameSorted, new ReadQueryNameComparator()); + assertSorted(querynameSorted.stream().map(read -> read.convertToSAMRecord(header)).collect(Collectors.toList()), new SAMRecordQueryNameComparator()); + } + + @Test + public void testSortQuerynameFixesPartitionBoundaries(){ + JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + final SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); + 
header.setSortOrder(SAMFileHeader.SortOrder.queryname); + final int numReadsWithSameName = 4; + final List pairedReads = createPairedReads(header, 100, numReadsWithSameName); + final int numPartitions = 7; + final JavaRDD reads = ctx.parallelize(pairedReads, numPartitions); + + //assert that the grouping is not correct before sorting + final List[] partitions = reads.collectPartitions(IntStream.range(0, reads.getNumPartitions()).toArray()); + Assert.assertTrue( + Arrays.stream(partitions) + //look through each partition and count the number of each read name seen + .flatMap( readsInPartition -> readsInPartition.stream() + .collect(Collectors.groupingBy(GATKRead::getName)) + .values() + .stream() + .map(List::size) + ) + //check that at least one partition was not correctly distributed + .anyMatch(size -> size != numReadsWithSameName), "The partitioning was correct before sorting so the test is meaningless."); + + final JavaRDD sorted = SparkUtils.querynameSortReads(reads, header, numPartitions); + //assert that the grouping is fixed after sorting + final List[] sortedPartitions = sorted.collectPartitions(IntStream.range(0, sorted.getNumPartitions()).toArray()); + Assert.assertTrue(Arrays.stream(sortedPartitions) + .flatMap( readsInPartition -> readsInPartition.stream() + .collect(Collectors.groupingBy(GATKRead::getName)) + .values() + .stream() + .map(List::size) + ) + .allMatch(size -> size == numReadsWithSameName), "Some reads names were split between multiple partitions"); } } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/test/ReadTestUtils.java b/src/test/java/org/broadinstitute/hellbender/utils/test/ReadTestUtils.java index ee1b95808cf..3d465ad717d 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/test/ReadTestUtils.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/test/ReadTestUtils.java @@ -1,15 +1,8 @@ package org.broadinstitute.hellbender.utils.test; -import htsjdk.samtools.Cigar; -import htsjdk.samtools.CigarElement; -import htsjdk.samtools.CigarOperator; -import htsjdk.samtools.SAMFileHeader; -import htsjdk.samtools.SAMRecord; -import htsjdk.samtools.SAMSequenceDictionary; +import htsjdk.samtools.*; import htsjdk.samtools.reference.IndexedFastaSequenceFile; import htsjdk.samtools.util.SequenceUtil; -import org.apache.commons.lang3.tuple.ImmutableTriple; -import org.broadinstitute.hellbender.utils.SimpleInterval; import org.broadinstitute.hellbender.utils.Utils; import org.broadinstitute.hellbender.utils.param.ParamUtils; From 16e617940cc142409d0d37b46a1fafbca39cdf63 Mon Sep 17 00:00:00 2001 From: Louis Bergelson Date: Thu, 17 May 2018 16:42:29 -0400 Subject: [PATCH 2/3] responding to comments --- .../spark/datasources/ReadsSparkSink.java | 34 +++++- .../tools/HaplotypeCallerSpark.java | 2 +- .../spark/pipelines/ReadsPipelineSpark.java | 2 +- .../markduplicates/MarkDuplicatesSpark.java | 7 +- .../hellbender/utils/spark/SparkUtils.java | 107 ++++++------------ .../utils/spark/SparkUtilsUnitTest.java | 36 ++++-- 6 files changed, 104 insertions(+), 84 deletions(-) diff --git a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java index 871e99c5bd3..79822331c59 100644 --- a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java +++ b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java @@ -26,6 +26,7 @@ import 
org.broadinstitute.hellbender.utils.io.IOUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.read.GATKReadToBDGAlignmentRecordConverter; +import org.broadinstitute.hellbender.utils.read.HeaderlessSAMRecordCoordinateComparator; import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat; import org.broadinstitute.hellbender.utils.spark.SparkUtils; import org.seqdoop.hadoop_bam.*; @@ -34,6 +35,7 @@ import java.io.File; import java.io.IOException; +import java.util.Comparator; /** * ReadsSparkSink writes GATKReads to a file. This code lifts from the HadoopGenomics/Hadoop-BAM @@ -43,6 +45,36 @@ public final class ReadsSparkSink { private final static Logger logger = LogManager.getLogger(ReadsSparkSink.class); + /** + * Sorts the given reads according to the sort order in the header. + * @param reads the reads to sort + * @param header the header specifying the sort order, + * if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown} + * then no sort will be performed + * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers + * @return a sorted RDD of reads + */ + public static JavaRDD sortSamRecordsToMatchHeader(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { + final Comparator comparator = getSAMRecordComparator(header); + if ( comparator == null ) { + return reads; + } else { + return SparkUtils.sortUsingElementsAsKeys(reads, comparator, numReducers); + } + } + + //Returns the comparator to use or null if no sorting is required. + private static Comparator getSAMRecordComparator(final SAMFileHeader header) { + switch (header.getSortOrder()){ + case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header); + //duplicate isn't supported because it doesn't work right on headerless SAMRecords + case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark."); + case queryname: + case unsorted: return header.getSortOrder().getComparatorInstance(); + default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default. + } + } + // Output format class for writing BAM files through saveAsNewAPIHadoopFile. Must be public. public static class SparkBAMOutputFormat extends KeyIgnoringBAMOutputFormat { public static SAMFileHeader bamHeader = null; @@ -284,7 +316,7 @@ private static void writeReadsSingle( final JavaSparkContext ctx, final String outputFile, final String referenceFile, final SAMFormat samOutputFormat, final JavaRDD reads, final SAMFileHeader header, final int numReducers, final String outputPartsDir) throws IOException { - final JavaRDD sortedReads = SparkUtils.sortSamRecordsToMatchHeader(reads, header, numReducers); + final JavaRDD sortedReads = sortSamRecordsToMatchHeader(reads, header, numReducers); final String outputPartsDirectory = (outputPartsDir == null)? 
getDefaultPartsDirectory(outputFile) : outputPartsDir; saveAsShardedHadoopFiles(ctx, outputPartsDirectory, referenceFile, samOutputFormat, sortedReads, header, false); logger.info("Finished sorting the bam file and dumping read shards to disk, proceeding to merge the shards into a single file using the master thread"); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java b/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java index f18ad39c402..908ddacd263 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/HaplotypeCallerSpark.java @@ -182,7 +182,7 @@ public static void callVariantsWithHaplotypeCallerAndWriteOutput( // Reads must be coordinate sorted to use the overlaps partitioner final SAMFileHeader readsHeader = header.clone(); readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate); - final JavaRDD coordinateSortedReads = SparkUtils.coordinateSortReads(reads, readsHeader, numReducers); + final JavaRDD coordinateSortedReads = SparkUtils.sortReadsAccordingToHeader(reads, readsHeader, numReducers); final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, false, false, readsHeader, new ReferenceMultiSourceAdapter(reference)); final JavaRDD variants = callVariantsWithHaplotypeCaller(ctx, coordinateSortedReads, readsHeader, reference, intervals, hcArgs, shardingArgs); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java index 690afdf8d74..6fe4637d8f4 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/pipelines/ReadsPipelineSpark.java @@ -181,7 +181,7 @@ protected void runTool(final JavaSparkContext ctx) { // the overlaps partitioner requires that reads are coordinate-sorted final SAMFileHeader readsHeader = header.clone(); readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate); - markedFilteredReadsForBQSR = SparkUtils.coordinateSortReads(markedFilteredReadsForBQSR, readsHeader, numReducers); + markedFilteredReadsForBQSR = SparkUtils.sortReadsAccordingToHeader(markedFilteredReadsForBQSR, readsHeader, numReducers); } VariantsSparkSource variantsSparkSource = new VariantsSparkSource(ctx); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java index e85a86b3613..a19159d1e30 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/markduplicates/MarkDuplicatesSpark.java @@ -90,13 +90,12 @@ public static JavaRDD mark(final JavaRDD reads, final SAMFil final MarkDuplicatesScoringStrategy scoringStrategy, final OpticalDuplicateFinder opticalDuplicateFinder, final int numReducers, final boolean dontMarkUnmappedMates) { - JavaRDD sortedReadsForMarking; SAMFileHeader headerForTool = header.clone(); // If the input isn't queryname sorted, sort it before duplicate marking - sortedReadsForMarking = querynameSortReadsIfNecessary(reads, numReducers, headerForTool); + final JavaRDD sortedReadsForMarking = querynameSortReadsIfNecessary(reads, numReducers, headerForTool); - 
JavaPairRDD, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(headerForTool, scoringStrategy, opticalDuplicateFinder, sortedReadsForMarking, numReducers); + final JavaPairRDD, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(headerForTool, scoringStrategy, opticalDuplicateFinder, sortedReadsForMarking, numReducers); // Here we explicitly repartition the read names of the unmarked reads to match the partitioning of the original bam final JavaRDD> repartitionedReadNames = namesOfNonDuplicates @@ -139,7 +138,7 @@ private static JavaRDD querynameSortReadsIfNecessary(JavaRDD sortedReadsForMarking = reads; } else { headerForTool.setSortOrder(SAMFileHeader.SortOrder.queryname); - sortedReadsForMarking = SparkUtils.querynameSortReads(reads, headerForTool, numReducers); + sortedReadsForMarking = SparkUtils.sortReadsAccordingToHeader(reads, headerForTool, numReducers); } return sortedReadsForMarking; } diff --git a/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java b/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java index 530d217d5c9..586f4904d2f 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/spark/SparkUtils.java @@ -3,7 +3,6 @@ import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; import htsjdk.samtools.SAMFileHeader; -import htsjdk.samtools.SAMRecord; import htsjdk.samtools.SAMSequenceRecord; import htsjdk.samtools.SAMTextHeaderCodec; import htsjdk.samtools.util.BinaryCodec; @@ -11,8 +10,8 @@ import htsjdk.samtools.util.BlockCompressedStreamConstants; import htsjdk.samtools.util.RuntimeIOException; import org.apache.commons.io.FileUtils; -import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; import org.apache.log4j.Logger; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -22,11 +21,10 @@ import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink; import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; -import org.broadinstitute.hellbender.utils.read.*; import org.broadinstitute.hellbender.utils.Utils; +import org.broadinstitute.hellbender.utils.read.*; import scala.Tuple2; - import java.io.*; import java.util.ArrayList; import java.util.Comparator; @@ -127,73 +125,39 @@ public static boolean pathExists(final JavaSparkContext ctx, final Path targetPa } /** - * Sorts the given reads in coordinate sort order. - * @param reads the reads to sort - * @param header the reads header, which must specify coordinate sort order - * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers - * @return a sorted RDD of reads + * Do a total sort of an RDD of {@link GATKRead} according to the sort order in the header. + * @param reads a JavaRDD of reads which may or may not be sorted + * @param header a header which specifies the desired new sort order. + * Only {@link SAMFileHeader.SortOrder#coordinate} and {@link SAMFileHeader.SortOrder#queryname} are supported. 
+ * All others will result in {@link GATKException} + * @param numReducers number of reducers to use when sorting + * @return a new JavaRDD or reads which is globally sorted in a way that is consistent with the sort order given in the header */ - public static JavaRDD coordinateSortReads(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { - Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.coordinate), "Header must specify coordinate sort order, but was" + header.getSortOrder()); - - return sort(reads, new ReadCoordinateComparator(header), numReducers); - } - - /** - * Sorts the given reads in queryname sort order. - * This guarantees that all reads that have the same read name are placed into the same partition - * @param reads the reads to sort - * @param header header of the bam, the header is required to have been set to queryname order - * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers - * @return a sorted RDD of reads - */ - public static JavaRDD querynameSortReads(final JavaRDD reads, SAMFileHeader header, final int numReducers) { - Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.queryname), "Header must specify queryname sort order, but was " + header.getSortOrder()); - - final JavaRDD sortedReads = sort(reads, new ReadQueryNameComparator(), numReducers); - return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context())); - } - - /** - * Sorts the given reads according to the sort order in the header. - * @param reads the reads to sort - * @param header the header specifying the sort order, - * if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown} - * then no sort will be performed - * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers - * @return a sorted RDD of reads - */ - public static JavaRDD sortSamRecordsToMatchHeader(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { - final Comparator comparator = getSAMRecordComparator(header); - if ( comparator == null ) { - return reads; - } else { - return sort(reads, comparator, numReducers); - } - } - - //Returns the comparator to use or null if no sorting is required. - private static Comparator getSAMRecordComparator(final SAMFileHeader header) { - switch (header.getSortOrder()){ - case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header); - //duplicate isn't supported because it doesn't work right on headerless SAMRecords - case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark."); + public static JavaRDD sortReadsAccordingToHeader(final JavaRDD reads, final SAMFileHeader header, final int numReducers){ + final SAMFileHeader.SortOrder order = header.getSortOrder(); + switch (order){ + case coordinate: + return sortUsingElementsAsKeys(reads, new ReadCoordinateComparator(header), numReducers); case queryname: - case unsorted: return header.getSortOrder().getComparatorInstance(); - default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default. 
+ final JavaRDD sortedReads = sortUsingElementsAsKeys(reads, new ReadQueryNameComparator(), numReducers); + return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context())); + default: + throw new GATKException("Sort order: " + order + " is not supported."); } } /** - * do a total sort of an RDD so that all the elements in partition i are less than those in partition i+1 according to the given comparator + * Do a global sort of an RDD using the given comparator. + * This method uses the RDD elements themselves as the keys in the spark key/value sort. This may be inefficient + * if the comparator only uses looks at a small fraction of the element to perform the comparison. */ - private static JavaRDD sort(JavaRDD reads, Comparator comparator, int numReducers) { + public static JavaRDD sortUsingElementsAsKeys(JavaRDD elements, Comparator comparator, int numReducers) { Utils.nonNull(comparator); - Utils.nonNull(reads); + Utils.nonNull(elements); // Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount // of data going through the shuffle. - final JavaPairRDD rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null)); + final JavaPairRDD rddReadPairs = elements.mapToPair(read -> new Tuple2<>(read, (Void) null)); final JavaPairRDD readVoidPairs; if (numReducers > 0) { @@ -205,17 +169,20 @@ private static JavaRDD sort(JavaRDD reads, Comparator comparator, i } /** - * Ensure all reads with the same name appear in the same partition. - * Requires that the No shuffle is needed. - + * Ensure all reads with the same name appear in the same partition of a queryname sorted RDD. + * This avoids a global shuffle and only transfers the leading elements from each partition which is fast in most + * cases. + * + * The RDD must be queryname sorted. If there are so many reads with the same name that they span multiple partitions + * this will throw {@link GATKException}. */ public static JavaRDD putReadsWithTheSameNameInTheSamePartition(final SAMFileHeader header, final JavaRDD reads, final JavaSparkContext ctx) { Utils.validateArg(ReadUtils.isReadNameGroupedBam(header), () -> "Reads must be queryname grouped or sorted. 
" + "Actual sort:" + header.getSortOrder() + " Actual grouping:" +header.getGroupOrder()); int numPartitions = reads.getNumPartitions(); - final String firstGroupInBam = reads.first().getName(); + final String firstNameInBam = reads.first().getName(); // Find the first group in each partition - List> firstReadNamesInEachPartition = reads + List> firstReadNameGroupInEachPartition = reads .mapPartitions(it -> { PeekingIterator current = Iterators.peekingIterator(it); List firstGroup = new ArrayList<>(2); firstGroup.add(current.next()); @@ -229,7 +196,7 @@ public static JavaRDD putReadsWithTheSameNameInTheSamePartition(final // Checking for pathological cases (read name groups that span more than 2 partitions) String groupName = null; - for (List group : firstReadNamesInEachPartition) { + for (List group : firstReadNameGroupInEachPartition) { if (group!=null && !group.isEmpty()) { // If a read spans multiple partitions we expect its name to show up multiple times and we don't expect this to work properly if (groupName != null && group.get(0).getName().equals(groupName)) { @@ -241,17 +208,17 @@ public static JavaRDD putReadsWithTheSameNameInTheSamePartition(final } // Shift left, so that each partition will be joined with the first read group from the _next_ partition - List> firstReadInNextPartition = new ArrayList<>(firstReadNamesInEachPartition.subList(1, numPartitions)); + List> firstReadInNextPartition = new ArrayList<>(firstReadNameGroupInEachPartition.subList(1, numPartitions)); firstReadInNextPartition.add(null); // the last partition does not have any reads to add to it - // Join the reads with the first read from the _next_ partition, then filter out the first and/or last read if not in a pair + // Join the reads with the first read from the _next_ partition, then filter out the first reads in this partition return reads.zipPartitions(ctx.parallelize(firstReadInNextPartition, numPartitions), (FlatMapFunction2, Iterator>, GATKRead>) (it1, it2) -> { PeekingIterator current = Iterators.peekingIterator(it1); String firstName = current.peek().getName(); // Make sure we don't remove reads from the first partition - if (!firstGroupInBam.equals(firstName)) { - // skip the first read name group in the _current_ partition if it is the second in a pair since it will be handled in the previous partition + if (!firstNameInBam.equals(firstName)) { + // skip the first read name group in the _current_ partition since it will be handled in the previous partition while (current.hasNext() && current.peek() != null && current.peek().getName().equals(firstName)) { current.next(); } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java index 6936e91e160..0d72e7b8a54 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/spark/SparkUtilsUnitTest.java @@ -6,17 +6,16 @@ import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.spark.Partition; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; +import org.broadinstitute.hellbender.GATKBaseTest; import org.broadinstitute.hellbender.engine.ReadsDataSource; import org.broadinstitute.hellbender.engine.spark.SparkContextFactory; import 
org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; import org.broadinstitute.hellbender.utils.read.ArtificialReadUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; -import org.broadinstitute.hellbender.GATKBaseTest; import org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator; import org.broadinstitute.hellbender.utils.read.ReadQueryNameComparator; import org.broadinstitute.hellbender.utils.test.MiniClusterUtils; @@ -163,8 +162,22 @@ public void testReadsMustBeQueryGroupedToFixPartitions(){ Assert.assertThrows(IllegalArgumentException.class, () -> SparkUtils.putReadsWithTheSameNameInTheSamePartition(header, readsRDD, ctx)); } + @Test(expectedExceptions = GATKException.class) + public void testInvalidSort(){ + JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); + header.setSortOrder(SAMFileHeader.SortOrder.unsorted); + List reads = new ArrayList<>(); + for(int i = 0; i < 10; i++){ + //create reads with alternating contigs and a fixed start position + reads.add(ArtificialReadUtils.createArtificialRead(header, "READ"+i, i % header.getSequenceDictionary().size() , 100, 100)); + } + final JavaRDD readsRDD = ctx.parallelize(reads); + SparkUtils.sortReadsAccordingToHeader(readsRDD, header, 0); + } + @Test - public void testSortCoordinateSort() { + public void testSortCoordinateSortMatchesHtsjdk() { JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); List reads = new ArrayList<>(); @@ -173,13 +186,13 @@ public void testSortCoordinateSort() { reads.add(ArtificialReadUtils.createArtificialRead(header, "READ"+i, i % header.getSequenceDictionary().size() , 3000 - i, 100)); } final JavaRDD readsRDD = ctx.parallelize(reads); - final List coordinateSorted = SparkUtils.coordinateSortReads(readsRDD, header, 0).collect(); + final List coordinateSorted = SparkUtils.sortReadsAccordingToHeader(readsRDD, header, 0).collect(); assertSorted(coordinateSorted, new ReadCoordinateComparator(header)); assertSorted(coordinateSorted.stream().map(read -> read.convertToSAMRecord(header)).collect(Collectors.toList()), new SAMRecordCoordinateComparator()); } @Test - public void testSortQuerynameSort() { + public void testSortQuerynameSortMatchesHtsjdk() { JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader(); List reads = new ArrayList<>(); @@ -195,7 +208,7 @@ public void testSortQuerynameSort() { } header.setSortOrder(SAMFileHeader.SortOrder.queryname); final JavaRDD readsRDD = ctx.parallelize(reads); - final List querynameSorted = SparkUtils.querynameSortReads(readsRDD, header, 31).collect(); + final List querynameSorted = SparkUtils.sortReadsAccordingToHeader(readsRDD, header, 31).collect(); assertSorted(querynameSorted, new ReadQueryNameComparator()); assertSorted(querynameSorted.stream().map(read -> read.convertToSAMRecord(header)).collect(Collectors.toList()), new SAMRecordQueryNameComparator()); } @@ -224,7 +237,7 @@ public void testSortQuerynameFixesPartitionBoundaries(){ //check that at least one partition was not correctly distributed .anyMatch(size -> size != numReadsWithSameName), "The partitioning was correct before sorting so the test is meaningless."); - final JavaRDD sorted = 
SparkUtils.sortReadsAccordingToHeader(reads, header, numPartitions); //assert that the grouping is fixed after sorting final List[] sortedPartitions = sorted.collectPartitions(IntStream.range(0, sorted.getNumPartitions()).toArray()); @@ -237,4 +250,13 @@ public void testSortQuerynameFixesPartitionBoundaries(){ ) .allMatch(size -> size == numReadsWithSameName), "Some reads names were split between multiple partitions"); } + + @Test + public void testSortUsingElementsAsKeys(){ + final List unsorted = Arrays.asList(4, 2, 6, 0, 8); + final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); + final JavaRDD unsortedRDD = ctx.parallelize(unsorted); + final JavaRDD sorted = SparkUtils.sortUsingElementsAsKeys(unsortedRDD, Comparator.naturalOrder(), 2); + assertSorted(sorted.collect(), Comparator.naturalOrder()); + } } From 17bdc1ca7579f41b63b0498330f09687049cc645 Mon Sep 17 00:00:00 2001 From: Louis Bergelson Date: Fri, 18 May 2018 15:52:18 -0400 Subject: [PATCH 3/3] make method private --- .../spark/datasources/ReadsSparkSink.java | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java index 79822331c59..c6d596a7262 100644 --- a/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java +++ b/src/main/java/org/broadinstitute/hellbender/engine/spark/datasources/ReadsSparkSink.java @@ -45,36 +45,6 @@ public final class ReadsSparkSink { private final static Logger logger = LogManager.getLogger(ReadsSparkSink.class); - /** - * Sorts the given reads according to the sort order in the header. - * @param reads the reads to sort - * @param header the header specifying the sort order, - * if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown} - * then no sort will be performed - * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers - * @return a sorted RDD of reads - */ - public static JavaRDD sortSamRecordsToMatchHeader(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { - final Comparator comparator = getSAMRecordComparator(header); - if ( comparator == null ) { - return reads; - } else { - return SparkUtils.sortUsingElementsAsKeys(reads, comparator, numReducers); - } - } - - //Returns the comparator to use or null if no sorting is required. - private static Comparator getSAMRecordComparator(final SAMFileHeader header) { - switch (header.getSortOrder()){ - case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header); - //duplicate isn't supported because it doesn't work right on headerless SAMRecords - case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark."); - case queryname: - case unsorted: return header.getSortOrder().getComparatorInstance(); - default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default. - } - } - // Output format class for writing BAM files through saveAsNewAPIHadoopFile. Must be public. 
public static class SparkBAMOutputFormat extends KeyIgnoringBAMOutputFormat { public static SAMFileHeader bamHeader = null; @@ -395,4 +365,34 @@ public static String getDefaultPartsDirectory(String file) { return file + ".parts/"; } + /** + * Sorts the given reads according to the sort order in the header. + * @param reads the reads to sort + * @param header the header specifying the sort order, + * if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown} + * then no sort will be performed + * @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers + * @return a sorted RDD of reads + */ + private static JavaRDD sortSamRecordsToMatchHeader(final JavaRDD reads, final SAMFileHeader header, final int numReducers) { + final Comparator comparator = getSAMRecordComparator(header); + if ( comparator == null ) { + return reads; + } else { + return SparkUtils.sortUsingElementsAsKeys(reads, comparator, numReducers); + } + } + + //Returns the comparator to use or null if no sorting is required. + private static Comparator getSAMRecordComparator(final SAMFileHeader header) { + switch (header.getSortOrder()){ + case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header); + //duplicate isn't supported because it doesn't work right on headerless SAMRecords + case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark."); + case queryname: + case unsorted: return header.getSortOrder().getComparatorInstance(); + default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default. + } + } + }
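
For readers following the refactor: the new SparkUtils.sortUsingElementsAsKeys relies on a standard Spark pattern, pairing each element with a null value so the shuffle carries no extra payload, sorting by key with the supplied comparator, and then dropping the values again. The following is a minimal, self-contained sketch of that pattern under stated assumptions, not the GATK implementation; the class name ElementKeySortSketch, the local[2] driver, and the Integer example data are illustrative only.

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public final class ElementKeySortSketch {

    // Global sort of an RDD using the elements themselves as keys.
    // Note: the comparator must be serializable so Spark can ship it to executors.
    public static <T> JavaRDD<T> sortUsingElementsAsKeys(final JavaRDD<T> elements,
                                                         final Comparator<T> comparator,
                                                         final int numReducers) {
        // Values are null so no extra data goes through the shuffle.
        final JavaPairRDD<T, Void> keyed = elements.mapToPair(e -> new Tuple2<>(e, (Void) null));
        final JavaPairRDD<T, Void> sorted = numReducers > 0
                ? keyed.sortByKey(comparator, true, numReducers)
                : keyed.sortByKey(comparator);
        return sorted.keys();
    }

    public static void main(final String[] args) {
        try (final JavaSparkContext ctx = new JavaSparkContext("local[2]", "sortSketch")) {
            final List<Integer> unsorted = Arrays.asList(4, 2, 6, 0, 8);
            final List<Integer> sorted = sortUsingElementsAsKeys(
                    ctx.parallelize(unsorted, 2), Comparator.<Integer>naturalOrder(), 2).collect();
            System.out.println(sorted); // prints [0, 2, 4, 6, 8]
        }
    }
}

As the javadoc added in the patch notes, keying on whole elements can be wasteful when the comparator inspects only a small part of each element, since the entire element is serialized through the shuffle either way.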
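
The sortSamRecordsToMatchHeader/getSAMRecordComparator pair that PATCH 3/3 makes private dispatches on the header's declared sort order to pick a SAMRecord comparator, or returns null when no sort is needed. Below is a hedged sketch of that dispatch using only htsjdk comparators; it is an illustration, not the GATK code, which substitutes its HeaderlessSAMRecordCoordinateComparator for the coordinate case because the records being sorted may have had their headers stripped, and which throws a GATK UserException rather than IllegalArgumentException.

import java.util.Comparator;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMRecordCoordinateComparator;
import htsjdk.samtools.SAMRecordQueryNameComparator;

final class SamRecordComparatorSketch {

    // Returns a comparator matching the header's sort order, or null when the
    // records can be left in their current order (unsorted/unknown).
    static Comparator<SAMRecord> comparatorFor(final SAMFileHeader header) {
        switch (header.getSortOrder()) {
            case coordinate:
                // Assumption: plain htsjdk comparator; GATK uses a headerless-safe variant here.
                return new SAMRecordCoordinateComparator();
            case queryname:
                return new SAMRecordQueryNameComparator();
            case duplicate:
                // Mirrors the patch: "duplicate" ordering is rejected because it does not
                // behave correctly on headerless SAMRecords.
                throw new IllegalArgumentException("The sort order \"duplicate\" is not supported.");
            default:
                return null;
        }
    }
}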
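
Finally, putReadsWithTheSameNameInTheSamePartition works by collecting each partition's leading run of identically named reads, shifting that list left by one, and zipping it back so every partition appends the run its successor gave up, with a guard against a name group spanning more than two partitions. The toy sketch below shows the same shift-left/zipPartitions shape on a queryname-sorted RDD of plain Strings standing in for read names; the class name BoundaryFixSketch and all identifiers are assumptions for illustration, it skips the header validation the real method performs, and it throws IllegalStateException where GATK throws GATKException.

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction2;

import com.google.common.collect.Iterators;
import com.google.common.collect.PeekingIterator;

final class BoundaryFixSketch {

    static JavaRDD<String> fixBoundaries(final JavaRDD<String> sortedNames, final JavaSparkContext ctx) {
        final int numPartitions = sortedNames.getNumPartitions();
        final String firstNameOverall = sortedNames.first();

        // Collect the leading run of equal names from every partition.
        final List<List<String>> firstGroupPerPartition = sortedNames.mapPartitions(it -> {
            final PeekingIterator<String> peeking = Iterators.peekingIterator(it);
            final List<String> group = new ArrayList<>();
            if (peeking.hasNext()) {
                group.add(peeking.next());
                while (peeking.hasNext() && peeking.peek().equals(group.get(0))) {
                    group.add(peeking.next());
                }
            }
            return Iterators.singletonIterator(group);
        }).collect();

        // Reject the pathological case where one name group is the leading group of two
        // consecutive partitions (i.e. spans more than two partitions); without this check
        // the group would be emitted twice.
        for (int i = 1; i < numPartitions; i++) {
            final List<String> prev = firstGroupPerPartition.get(i - 1);
            final List<String> here = firstGroupPerPartition.get(i);
            if (!prev.isEmpty() && !here.isEmpty() && prev.get(0).equals(here.get(0))) {
                throw new IllegalStateException("Name group spans more than two partitions: " + here.get(0));
            }
        }

        // Shift left: partition i receives the leading group of partition i + 1.
        final List<List<String>> groupFromNextPartition =
                new ArrayList<>(firstGroupPerPartition.subList(1, numPartitions));
        groupFromNextPartition.add(new ArrayList<>()); // the last partition receives nothing

        return sortedNames.zipPartitions(
                ctx.parallelize(groupFromNextPartition, numPartitions),
                (FlatMapFunction2<Iterator<String>, Iterator<List<String>>, String>) (names, extra) -> {
                    final PeekingIterator<String> peeking = Iterators.peekingIterator(names);
                    final List<String> out = new ArrayList<>();
                    if (peeking.hasNext()) {
                        final String firstHere = peeking.peek();
                        // Drop our own leading group unless this is the very first partition;
                        // the previous partition emits it instead.
                        if (!firstNameOverall.equals(firstHere)) {
                            while (peeking.hasNext() && peeking.peek().equals(firstHere)) {
                                peeking.next();
                            }
                        }
                    }
                    peeking.forEachRemaining(out::add);
                    extra.forEachRemaining(out::addAll);
                    return out.iterator();
                });
    }
}

Because only the first few elements of each partition move, this repair is far cheaper than the full shuffle a global re-sort would require, which is the point of doing the edge fixing in querynameSortReads rather than coordinateSortReads.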