fixes and refactoring methods in SparkUtils #4765

Merged: 3 commits, May 18, 2018
Changes from 1 commit
@@ -26,6 +26,7 @@
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.GATKReadToBDGAlignmentRecordConverter;
import org.broadinstitute.hellbender.utils.read.HeaderlessSAMRecordCoordinateComparator;
import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat;
import org.broadinstitute.hellbender.utils.spark.SparkUtils;
import org.seqdoop.hadoop_bam.*;
@@ -34,6 +35,7 @@

import java.io.File;
import java.io.IOException;
import java.util.Comparator;

/**
* ReadsSparkSink writes GATKReads to a file. This code lifts from the HadoopGenomics/Hadoop-BAM
@@ -43,6 +45,36 @@ public final class ReadsSparkSink {

private final static Logger logger = LogManager.getLogger(ReadsSparkSink.class);

/**
* Sorts the given reads according to the sort order in the header.
* @param reads the reads to sort
* @param header the header specifying the sort order,
* if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown}
* then no sort will be performed
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
*/
public static JavaRDD<SAMRecord> sortSamRecordsToMatchHeader(final JavaRDD<SAMRecord> reads, final SAMFileHeader header, final int numReducers) {
Review comment (Collaborator):
Do we want to worry about the repartitioning for SamRecords here? It might be just a hair too much especially because we would have to rewrite our partitioner somewhere... hmm...

Review comment (Member Author):
this was not meant to be a public method...

final Comparator<SAMRecord> comparator = getSAMRecordComparator(header);
if ( comparator == null ) {
return reads;
} else {
return SparkUtils.sortUsingElementsAsKeys(reads, comparator, numReducers);
}
}

//Returns the comparator to use or null if no sorting is required.
private static Comparator<SAMRecord> getSAMRecordComparator(final SAMFileHeader header) {
switch (header.getSortOrder()){
case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header);
//duplicate isn't supported because it doesn't work right on headerless SAMRecords
case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark.");
case queryname:
case unsorted: return header.getSortOrder().getComparatorInstance();
default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default.
}
}
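For orientation, here is a minimal usage sketch of the new helper as it would be called from within ReadsSparkSink (the readsToWrite, header, and numReducers names are hypothetical; the actual call site is the writeReadsSingle change further down in this diff):

// Hedged sketch, not part of this PR: sort records to match whatever order the header declares
// before writing. For unsorted or unknown sort orders the input RDD is returned unchanged,
// so the call is safe to make unconditionally.
final JavaRDD<SAMRecord> recordsForOutput =
        sortSamRecordsToMatchHeader(readsToWrite, header, numReducers);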

// Output format class for writing BAM files through saveAsNewAPIHadoopFile. Must be public.
public static class SparkBAMOutputFormat extends KeyIgnoringBAMOutputFormat<NullWritable> {
public static SAMFileHeader bamHeader = null;
@@ -284,7 +316,7 @@ private static void writeReadsSingle(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final SAMFormat samOutputFormat, final JavaRDD<SAMRecord> reads,
final SAMFileHeader header, final int numReducers, final String outputPartsDir) throws IOException {

final JavaRDD<SAMRecord> sortedReads = SparkUtils.sortSamRecordsToMatchHeader(reads, header, numReducers);
final JavaRDD<SAMRecord> sortedReads = sortSamRecordsToMatchHeader(reads, header, numReducers);
final String outputPartsDirectory = (outputPartsDir == null)? getDefaultPartsDirectory(outputFile) : outputPartsDir;
saveAsShardedHadoopFiles(ctx, outputPartsDirectory, referenceFile, samOutputFormat, sortedReads, header, false);
logger.info("Finished sorting the bam file and dumping read shards to disk, proceeding to merge the shards into a single file using the master thread");
(next file)
@@ -182,7 +182,7 @@ public static void callVariantsWithHaplotypeCallerAndWriteOutput(
// Reads must be coordinate sorted to use the overlaps partitioner
final SAMFileHeader readsHeader = header.clone();
readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
final JavaRDD<GATKRead> coordinateSortedReads = SparkUtils.coordinateSortReads(reads, readsHeader, numReducers);
final JavaRDD<GATKRead> coordinateSortedReads = SparkUtils.sortReadsAccordingToHeader(reads, readsHeader, numReducers);

final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, false, false, readsHeader, new ReferenceMultiSourceAdapter(reference));
final JavaRDD<VariantContext> variants = callVariantsWithHaplotypeCaller(ctx, coordinateSortedReads, readsHeader, reference, intervals, hcArgs, shardingArgs);
(next file)
@@ -181,7 +181,7 @@ protected void runTool(final JavaSparkContext ctx) {
// the overlaps partitioner requires that reads are coordinate-sorted
final SAMFileHeader readsHeader = header.clone();
readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
markedFilteredReadsForBQSR = SparkUtils.coordinateSortReads(markedFilteredReadsForBQSR, readsHeader, numReducers);
markedFilteredReadsForBQSR = SparkUtils.sortReadsAccordingToHeader(markedFilteredReadsForBQSR, readsHeader, numReducers);
}

VariantsSparkSource variantsSparkSource = new VariantsSparkSource(ctx);
(next file)
@@ -90,13 +90,12 @@ public static JavaRDD<GATKRead> mark(final JavaRDD<GATKRead> reads, final SAMFil
final MarkDuplicatesScoringStrategy scoringStrategy,
final OpticalDuplicateFinder opticalDuplicateFinder,
final int numReducers, final boolean dontMarkUnmappedMates) {
JavaRDD<GATKRead> sortedReadsForMarking;
SAMFileHeader headerForTool = header.clone();

// If the input isn't queryname sorted, sort it before duplicate marking
sortedReadsForMarking = querynameSortReadsIfNecessary(reads, numReducers, headerForTool);
final JavaRDD<GATKRead> sortedReadsForMarking = querynameSortReadsIfNecessary(reads, numReducers, headerForTool);

JavaPairRDD<MarkDuplicatesSparkUtils.IndexPair<String>, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(headerForTool, scoringStrategy, opticalDuplicateFinder, sortedReadsForMarking, numReducers);
final JavaPairRDD<MarkDuplicatesSparkUtils.IndexPair<String>, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(headerForTool, scoringStrategy, opticalDuplicateFinder, sortedReadsForMarking, numReducers);

// Here we explicitly repartition the read names of the unmarked reads to match the partitioning of the original bam
final JavaRDD<Tuple2<String,Integer>> repartitionedReadNames = namesOfNonDuplicates
@@ -139,7 +138,7 @@ private static JavaRDD<GATKRead> querynameSortReadsIfNecessary(JavaRDD<GATKRead>
sortedReadsForMarking = reads;
} else {
headerForTool.setSortOrder(SAMFileHeader.SortOrder.queryname);
sortedReadsForMarking = SparkUtils.querynameSortReads(reads, headerForTool, numReducers);
sortedReadsForMarking = SparkUtils.sortReadsAccordingToHeader(reads, headerForTool, numReducers);
}
return sortedReadsForMarking;
}
(next file)
@@ -3,16 +3,15 @@
import com.google.common.collect.Iterators;
import com.google.common.collect.PeekingIterator;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.SAMTextHeaderCodec;
import htsjdk.samtools.util.BinaryCodec;
import htsjdk.samtools.util.BlockCompressedOutputStream;
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.RuntimeIOException;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
@@ -22,11 +21,10 @@
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.read.*;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.*;
import scala.Tuple2;


import java.io.*;
import java.util.ArrayList;
import java.util.Comparator;
@@ -127,73 +125,39 @@ public static boolean pathExists(final JavaSparkContext ctx, final Path targetPa
}

/**
* Sorts the given reads in coordinate sort order.
* @param reads the reads to sort
* @param header the reads header, which must specify coordinate sort order
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
* Do a total sort of an RDD of {@link GATKRead} according to the sort order in the header.
* @param reads a JavaRDD of reads which may or may not be sorted
* @param header a header which specifies the desired new sort order.
* Only {@link SAMFileHeader.SortOrder#coordinate} and {@link SAMFileHeader.SortOrder#queryname} are supported.
* All others will result in a {@link GATKException}.
* @param numReducers number of reducers to use when sorting
* @return a new JavaRDD of reads which is globally sorted in a way that is consistent with the sort order given in the header
*/
public static JavaRDD<GATKRead> coordinateSortReads(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers) {
Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.coordinate), "Header must specify coordinate sort order, but was" + header.getSortOrder());

return sort(reads, new ReadCoordinateComparator(header), numReducers);
}

/**
* Sorts the given reads in queryname sort order.
* This guarantees that all reads that have the same read name are placed into the same partition
* @param reads the reads to sort
* @param header header of the bam, the header is required to have been set to queryname order
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
*/
public static JavaRDD<GATKRead> querynameSortReads(final JavaRDD<GATKRead> reads, SAMFileHeader header, final int numReducers) {
Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.queryname), "Header must specify queryname sort order, but was " + header.getSortOrder());

final JavaRDD<GATKRead> sortedReads = sort(reads, new ReadQueryNameComparator(), numReducers);
return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context()));
}

/**
* Sorts the given reads according to the sort order in the header.
* @param reads the reads to sort
* @param header the header specifying the sort order,
* if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown}
* then no sort will be performed
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
*/
public static JavaRDD<SAMRecord> sortSamRecordsToMatchHeader(final JavaRDD<SAMRecord> reads, final SAMFileHeader header, final int numReducers) {
final Comparator<SAMRecord> comparator = getSAMRecordComparator(header);
if ( comparator == null ) {
return reads;
} else {
return sort(reads, comparator, numReducers);
}
}

//Returns the comparator to use or null if no sorting is required.
private static Comparator<SAMRecord> getSAMRecordComparator(final SAMFileHeader header) {
switch (header.getSortOrder()){
case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header);
//duplicate isn't supported because it doesn't work right on headerless SAMRecords
case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark.");
public static JavaRDD<GATKRead> sortReadsAccordingToHeader(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers){
final SAMFileHeader.SortOrder order = header.getSortOrder();
switch (order){
case coordinate:
return sortUsingElementsAsKeys(reads, new ReadCoordinateComparator(header), numReducers);
case queryname:
case unsorted: return header.getSortOrder().getComparatorInstance();
default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default.
final JavaRDD<GATKRead> sortedReads = sortUsingElementsAsKeys(reads, new ReadQueryNameComparator(), numReducers);
return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context()));
default:
throw new GATKException("Sort order: " + order + " is not supported.");
}
}
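A rough usage sketch mirroring the call sites changed earlier in this diff (the reads, header, and numReducers variables are assumed to already exist in the caller):

// Hedged sketch, not part of this PR: coordinate-sort reads before handing them to code that
// requires coordinate order, such as the overlaps partitioner mentioned above. Passing a header
// whose sort order is anything other than coordinate or queryname would throw a GATKException.
final SAMFileHeader sortedHeader = header.clone();
sortedHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
final JavaRDD<GATKRead> coordinateSortedReads =
        SparkUtils.sortReadsAccordingToHeader(reads, sortedHeader, numReducers);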

/**
* do a total sort of an RDD so that all the elements in partition i are less than those in partition i+1 according to the given comparator
* Do a global sort of an RDD using the given comparator.
* This method uses the RDD elements themselves as the keys in the Spark key/value sort. This may be inefficient
* if the comparator only looks at a small fraction of each element to perform the comparison.
*/
private static <T> JavaRDD<T> sort(JavaRDD<T> reads, Comparator<T> comparator, int numReducers) {
public static <T> JavaRDD<T> sortUsingElementsAsKeys(JavaRDD<T> elements, Comparator<T> comparator, int numReducers) {
Utils.nonNull(comparator);
Utils.nonNull(reads);
Utils.nonNull(elements);

// Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount
// of data going through the shuffle.
final JavaPairRDD<T, Void> rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null));
final JavaPairRDD<T, Void> rddReadPairs = elements.mapToPair(read -> new Tuple2<>(read, (Void) null));

final JavaPairRDD<T, Void> readVoidPairs;
if (numReducers > 0) {
@@ -205,17 +169,20 @@ private static <T> JavaRDD<T> sort(JavaRDD<T> reads, Comparator<T> comparator, i
}
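Because the helper is now public and generic, it can be exercised on any RDD. A minimal sketch on plain strings, assuming a live JavaSparkContext named ctx (all names here are illustrative):

// Hedged sketch, not part of this PR: a global sort driven by an explicit comparator.
// Comparator.naturalOrder() is an enum singleton and therefore Serializable, which Spark requires.
// Passing 0 for numReducers falls back to the default number of reducers, as the original javadoc described.
final JavaRDD<String> names =
        ctx.parallelize(java.util.Arrays.asList("delta", "alpha", "charlie", "bravo"));
final JavaRDD<String> sortedNames =
        SparkUtils.sortUsingElementsAsKeys(names, java.util.Comparator.naturalOrder(), 0);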

/**
* Ensure all reads with the same name appear in the same partition.
* Requires that the No shuffle is needed.

* Ensure all reads with the same name appear in the same partition of a queryname sorted RDD.
* This avoids a global shuffle and only transfers the leading elements from each partition, which is fast in most
* cases.
*
* The RDD must be queryname sorted. If there are so many reads with the same name that they span multiple partitions
* this will throw {@link GATKException}.
*/
public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition(final SAMFileHeader header, final JavaRDD<GATKRead> reads, final JavaSparkContext ctx) {
Utils.validateArg(ReadUtils.isReadNameGroupedBam(header), () -> "Reads must be queryname grouped or sorted. " +
"Actual sort:" + header.getSortOrder() + " Actual grouping:" +header.getGroupOrder());
int numPartitions = reads.getNumPartitions();
final String firstGroupInBam = reads.first().getName();
final String firstNameInBam = reads.first().getName();
// Find the first group in each partition
List<List<GATKRead>> firstReadNamesInEachPartition = reads
List<List<GATKRead>> firstReadNameGroupInEachPartition = reads
.mapPartitions(it -> { PeekingIterator<GATKRead> current = Iterators.peekingIterator(it);
List<GATKRead> firstGroup = new ArrayList<>(2);
firstGroup.add(current.next());
@@ -229,7 +196,7 @@ public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition(final

// Checking for pathological cases (read name groups that span more than 2 partitions)
String groupName = null;
for (List<GATKRead> group : firstReadNamesInEachPartition) {
for (List<GATKRead> group : firstReadNameGroupInEachPartition) {
if (group!=null && !group.isEmpty()) {
// If a read spans multiple partitions we expect its name to show up multiple times and we don't expect this to work properly
if (groupName != null && group.get(0).getName().equals(groupName)) {
@@ -241,17 +208,17 @@ public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition(final
}

// Shift left, so that each partition will be joined with the first read group from the _next_ partition
List<List<GATKRead>> firstReadInNextPartition = new ArrayList<>(firstReadNamesInEachPartition.subList(1, numPartitions));
List<List<GATKRead>> firstReadInNextPartition = new ArrayList<>(firstReadNameGroupInEachPartition.subList(1, numPartitions));
firstReadInNextPartition.add(null); // the last partition does not have any reads to add to it

// Join the reads with the first read from the _next_ partition, then filter out the first and/or last read if not in a pair
// Join the reads with the first read from the _next_ partition, then filter out the first reads in this partition
return reads.zipPartitions(ctx.parallelize(firstReadInNextPartition, numPartitions),
(FlatMapFunction2<Iterator<GATKRead>, Iterator<List<GATKRead>>, GATKRead>) (it1, it2) -> {
PeekingIterator<GATKRead> current = Iterators.peekingIterator(it1);
String firstName = current.peek().getName();
// Make sure we don't remove reads from the first partition
if (!firstGroupInBam.equals(firstName)) {
// skip the first read name group in the _current_ partition if it is the second in a pair since it will be handled in the previous partition
if (!firstNameInBam.equals(firstName)) {
// skip the first read name group in the _current_ partition since it will be handled in the previous partition
while (current.hasNext() && current.peek() != null && current.peek().getName().equals(firstName)) {
current.next();
}
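Finally, the partition-boundary fix-up in putReadsWithTheSameNameInTheSamePartition is the subtlest part of the change: the leading read-name group of every partition is collected on the driver, shifted left by one, and appended to the previous partition. Below is a simplified, hedged analogue of the same technique on an already-sorted, non-empty RDD of strings; the class and method names are invented for illustration and are not part of the PR.

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import com.google.common.collect.Iterators;
import com.google.common.collect.PeekingIterator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.FlatMapFunction2;

final class PartitionBoundaryFixupSketch {

    // Given an RDD whose elements are already sorted, move each partition's leading run of equal
    // values into the preceding partition so that equal values never straddle a partition boundary.
    static JavaRDD<String> groupEqualValuesIntoOnePartition(final JavaSparkContext ctx, final JavaRDD<String> sorted) {
        final int numPartitions = sorted.getNumPartitions();
        final String firstValueInRdd = sorted.first();

        // Collect the leading run of equal values from every partition; only the head of each partition is read.
        final List<List<String>> firstRunPerPartition = sorted.mapPartitions(
                (FlatMapFunction<Iterator<String>, List<String>>) it -> {
                    final PeekingIterator<String> head = Iterators.peekingIterator(it);
                    final List<String> run = new ArrayList<>();
                    if (head.hasNext()) {
                        run.add(head.next());
                        while (head.hasNext() && head.peek().equals(run.get(0))) {
                            run.add(head.next());
                        }
                    }
                    return Iterators.singletonIterator(run);
                }).collect();

        // Shift left so that partition i is paired with the leading run of partition i + 1.
        final List<List<String>> runFromNextPartition = new ArrayList<>(firstRunPerPartition.subList(1, numPartitions));
        runFromNextPartition.add(Collections.<String>emptyList()); // the last partition has no successor

        return sorted.zipPartitions(ctx.parallelize(runFromNextPartition, numPartitions),
                (FlatMapFunction2<Iterator<String>, Iterator<List<String>>, String>) (partition, extra) -> {
                    final PeekingIterator<String> current = Iterators.peekingIterator(partition);
                    if (current.hasNext()) {
                        final String first = current.peek();
                        // Drop this partition's leading run unless it is the very first run of the whole RDD;
                        // the previous partition appends that run instead, via runFromNextPartition.
                        if (!firstValueInRdd.equals(first)) {
                            while (current.hasNext() && current.peek().equals(first)) {
                                current.next();
                            }
                        }
                    }
                    final List<String> tail = extra.hasNext() ? extra.next() : Collections.<String>emptyList();
                    return Iterators.concat(current, tail.iterator());
                });
    }
}

Design note: only the head of each partition crosses the network, so the fix-up stays cheap relative to a full shuffle, which is the trade-off the new javadoc describes.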