Commit: responding to comments
lbergelson committed May 17, 2018
1 parent 63f22f1 commit 16e6179
Showing 6 changed files with 104 additions and 84 deletions.
Changed file 1 of 6
@@ -26,6 +26,7 @@
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.GATKReadToBDGAlignmentRecordConverter;
import org.broadinstitute.hellbender.utils.read.HeaderlessSAMRecordCoordinateComparator;
import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat;
import org.broadinstitute.hellbender.utils.spark.SparkUtils;
import org.seqdoop.hadoop_bam.*;
@@ -34,6 +35,7 @@

import java.io.File;
import java.io.IOException;
import java.util.Comparator;

/**
* ReadsSparkSink writes GATKReads to a file. This code lifts from the HadoopGenomics/Hadoop-BAM
@@ -43,6 +45,36 @@ public final class ReadsSparkSink {

private final static Logger logger = LogManager.getLogger(ReadsSparkSink.class);

/**
* Sorts the given reads according to the sort order in the header.
* @param reads the reads to sort
* @param header the header specifying the sort order,
* if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown}
* then no sort will be performed
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
*/
public static JavaRDD<SAMRecord> sortSamRecordsToMatchHeader(final JavaRDD<SAMRecord> reads, final SAMFileHeader header, final int numReducers) {
final Comparator<SAMRecord> comparator = getSAMRecordComparator(header);
if ( comparator == null ) {
return reads;
} else {
return SparkUtils.sortUsingElementsAsKeys(reads, comparator, numReducers);
}
}

//Returns the comparator to use or null if no sorting is required.
private static Comparator<SAMRecord> getSAMRecordComparator(final SAMFileHeader header) {
switch (header.getSortOrder()){
case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header);
//duplicate isn't supported because it doesn't work right on headerless SAMRecords
case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark.");
case queryname:
case unsorted: return header.getSortOrder().getComparatorInstance();
default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default.
}
}
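
// Illustrative usage sketch (editor's addition, not part of this commit): how a caller
// inside this class might use the relocated sort before writing. The variable names
// "records" and "header" are hypothetical stand-ins for whatever the surrounding tool provides.
private static JavaRDD<SAMRecord> exampleSortBeforeWriting(final JavaRDD<SAMRecord> records,
                                                           final SAMFileHeader header) {
    // Headers declaring unsorted/unknown come back unchanged; coordinate headers use
    // HeaderlessSAMRecordCoordinateComparator, and queryname uses the comparator
    // instance supplied by the SortOrder itself.
    return sortSamRecordsToMatchHeader(records, header, 0 /* 0 = default number of reducers */);
}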

// Output format class for writing BAM files through saveAsNewAPIHadoopFile. Must be public.
public static class SparkBAMOutputFormat extends KeyIgnoringBAMOutputFormat<NullWritable> {
public static SAMFileHeader bamHeader = null;
@@ -284,7 +316,7 @@ private static void writeReadsSingle(
final JavaSparkContext ctx, final String outputFile, final String referenceFile, final SAMFormat samOutputFormat, final JavaRDD<SAMRecord> reads,
final SAMFileHeader header, final int numReducers, final String outputPartsDir) throws IOException {

final JavaRDD<SAMRecord> sortedReads = SparkUtils.sortSamRecordsToMatchHeader(reads, header, numReducers);
final JavaRDD<SAMRecord> sortedReads = sortSamRecordsToMatchHeader(reads, header, numReducers);
final String outputPartsDirectory = (outputPartsDir == null)? getDefaultPartsDirectory(outputFile) : outputPartsDir;
saveAsShardedHadoopFiles(ctx, outputPartsDirectory, referenceFile, samOutputFormat, sortedReads, header, false);
logger.info("Finished sorting the bam file and dumping read shards to disk, proceeding to merge the shards into a single file using the master thread");
Changed file 2 of 6
@@ -182,7 +182,7 @@ public static void callVariantsWithHaplotypeCallerAndWriteOutput(
// Reads must be coordinate sorted to use the overlaps partitioner
final SAMFileHeader readsHeader = header.clone();
readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
final JavaRDD<GATKRead> coordinateSortedReads = SparkUtils.coordinateSortReads(reads, readsHeader, numReducers);
final JavaRDD<GATKRead> coordinateSortedReads = SparkUtils.sortReadsAccordingToHeader(reads, readsHeader, numReducers);

final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, false, false, readsHeader, new ReferenceMultiSourceAdapter(reference));
final JavaRDD<VariantContext> variants = callVariantsWithHaplotypeCaller(ctx, coordinateSortedReads, readsHeader, reference, intervals, hcArgs, shardingArgs);
Changed file 3 of 6
@@ -181,7 +181,7 @@ protected void runTool(final JavaSparkContext ctx) {
// the overlaps partitioner requires that reads are coordinate-sorted
final SAMFileHeader readsHeader = header.clone();
readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
markedFilteredReadsForBQSR = SparkUtils.coordinateSortReads(markedFilteredReadsForBQSR, readsHeader, numReducers);
markedFilteredReadsForBQSR = SparkUtils.sortReadsAccordingToHeader(markedFilteredReadsForBQSR, readsHeader, numReducers);
}

VariantsSparkSource variantsSparkSource = new VariantsSparkSource(ctx);
Changed file 4 of 6
@@ -90,13 +90,12 @@ public static JavaRDD<GATKRead> mark(final JavaRDD<GATKRead> reads, final SAMFil
final MarkDuplicatesScoringStrategy scoringStrategy,
final OpticalDuplicateFinder opticalDuplicateFinder,
final int numReducers, final boolean dontMarkUnmappedMates) {
JavaRDD<GATKRead> sortedReadsForMarking;
SAMFileHeader headerForTool = header.clone();

// If the input isn't queryname sorted, sort it before duplicate marking
sortedReadsForMarking = querynameSortReadsIfNecessary(reads, numReducers, headerForTool);
final JavaRDD<GATKRead> sortedReadsForMarking = querynameSortReadsIfNecessary(reads, numReducers, headerForTool);

JavaPairRDD<MarkDuplicatesSparkUtils.IndexPair<String>, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(headerForTool, scoringStrategy, opticalDuplicateFinder, sortedReadsForMarking, numReducers);
final JavaPairRDD<MarkDuplicatesSparkUtils.IndexPair<String>, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(headerForTool, scoringStrategy, opticalDuplicateFinder, sortedReadsForMarking, numReducers);

// Here we explicitly repartition the read names of the unmarked reads to match the partitioning of the original bam
final JavaRDD<Tuple2<String,Integer>> repartitionedReadNames = namesOfNonDuplicates
@@ -139,7 +138,7 @@ private static JavaRDD<GATKRead> querynameSortReadsIfNecessary(JavaRDD<GATKRead>
sortedReadsForMarking = reads;
} else {
headerForTool.setSortOrder(SAMFileHeader.SortOrder.queryname);
sortedReadsForMarking = SparkUtils.querynameSortReads(reads, headerForTool, numReducers);
sortedReadsForMarking = SparkUtils.sortReadsAccordingToHeader(reads, headerForTool, numReducers);
}
return sortedReadsForMarking;
}
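
// Illustrative sketch (editor's addition, not part of this commit): the queryname branch
// mirrors the coordinate-sort pattern used by the other call sites in this commit; set the
// desired order on the tool's own header copy, then let SparkUtils pick the matching
// comparator. "reads", "header" and "numReducers" are hypothetical caller-supplied values.
private static JavaRDD<GATKRead> exampleQuerynameSort(final JavaRDD<GATKRead> reads,
                                                      final SAMFileHeader header,
                                                      final int numReducers) {
    final SAMFileHeader querynameHeader = header.clone();
    querynameHeader.setSortOrder(SAMFileHeader.SortOrder.queryname);
    return SparkUtils.sortReadsAccordingToHeader(reads, querynameHeader, numReducers);
}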
Changed file 5 of 6
@@ -3,16 +3,15 @@
import com.google.common.collect.Iterators;
import com.google.common.collect.PeekingIterator;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.SAMTextHeaderCodec;
import htsjdk.samtools.util.BinaryCodec;
import htsjdk.samtools.util.BlockCompressedOutputStream;
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.RuntimeIOException;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
@@ -22,11 +21,10 @@
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.read.*;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.*;
import scala.Tuple2;


import java.io.*;
import java.util.ArrayList;
import java.util.Comparator;
@@ -127,73 +125,39 @@ public static boolean pathExists(final JavaSparkContext ctx, final Path targetPa
}

/**
* Sorts the given reads in coordinate sort order.
* @param reads the reads to sort
* @param header the reads header, which must specify coordinate sort order
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
* Do a total sort of an RDD of {@link GATKRead} according to the sort order in the header.
* @param reads a JavaRDD of reads which may or may not be sorted
* @param header a header which specifies the desired new sort order.
* Only {@link SAMFileHeader.SortOrder#coordinate} and {@link SAMFileHeader.SortOrder#queryname} are supported.
* All others will result in {@link GATKException}
* @param numReducers number of reducers to use when sorting
* @return a new JavaRDD of reads which is globally sorted in a way that is consistent with the sort order given in the header
*/
public static JavaRDD<GATKRead> coordinateSortReads(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers) {
Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.coordinate), "Header must specify coordinate sort order, but was" + header.getSortOrder());

return sort(reads, new ReadCoordinateComparator(header), numReducers);
}

/**
* Sorts the given reads in queryname sort order.
* This guarantees that all reads that have the same read name are placed into the same partition
* @param reads the reads to sort
* @param header header of the bam, the header is required to have been set to queryname order
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
*/
public static JavaRDD<GATKRead> querynameSortReads(final JavaRDD<GATKRead> reads, SAMFileHeader header, final int numReducers) {
Utils.validate(header.getSortOrder().equals(SAMFileHeader.SortOrder.queryname), "Header must specify queryname sort order, but was " + header.getSortOrder());

final JavaRDD<GATKRead> sortedReads = sort(reads, new ReadQueryNameComparator(), numReducers);
return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context()));
}

/**
* Sorts the given reads according to the sort order in the header.
* @param reads the reads to sort
* @param header the header specifying the sort order,
* if the header specifies {@link SAMFileHeader.SortOrder#unsorted} or {@link SAMFileHeader.SortOrder#unknown}
* then no sort will be performed
* @param numReducers the number of reducers to use; a value of 0 means use the default number of reducers
* @return a sorted RDD of reads
*/
public static JavaRDD<SAMRecord> sortSamRecordsToMatchHeader(final JavaRDD<SAMRecord> reads, final SAMFileHeader header, final int numReducers) {
final Comparator<SAMRecord> comparator = getSAMRecordComparator(header);
if ( comparator == null ) {
return reads;
} else {
return sort(reads, comparator, numReducers);
}
}

//Returns the comparator to use or null if no sorting is required.
private static Comparator<SAMRecord> getSAMRecordComparator(final SAMFileHeader header) {
switch (header.getSortOrder()){
case coordinate: return new HeaderlessSAMRecordCoordinateComparator(header);
//duplicate isn't supported because it doesn't work right on headerless SAMRecords
case duplicate: throw new UserException.UnimplementedFeature("The sort order \"duplicate\" is not supported in Spark.");
public static JavaRDD<GATKRead> sortReadsAccordingToHeader(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers){
final SAMFileHeader.SortOrder order = header.getSortOrder();
switch (order){
case coordinate:
return sortUsingElementsAsKeys(reads, new ReadCoordinateComparator(header), numReducers);
case queryname:
case unsorted: return header.getSortOrder().getComparatorInstance();
default: return null; //NOTE: javac warns if you have this (useless) default BUT it errors out if you remove this default.
final JavaRDD<GATKRead> sortedReads = sortUsingElementsAsKeys(reads, new ReadQueryNameComparator(), numReducers);
return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context()));
default:
throw new GATKException("Sort order: " + order + " is not supported.");
}
}
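
// Illustrative sketch (editor's addition, not part of this commit): the call sites updated
// in this commit follow this pattern; clone the header, set the target order on the clone,
// then sort against it. "reads", "header" and "numReducers" are hypothetical caller-supplied values.
public static JavaRDD<GATKRead> exampleCoordinateSort(final JavaRDD<GATKRead> reads,
                                                      final SAMFileHeader header,
                                                      final int numReducers) {
    final SAMFileHeader coordinateSortedHeader = header.clone();
    coordinateSortedHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
    return sortReadsAccordingToHeader(reads, coordinateSortedHeader, numReducers);
}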

/**
* do a total sort of an RDD so that all the elements in partition i are less than those in partition i+1 according to the given comparator
* Do a global sort of an RDD using the given comparator.
* This method uses the RDD elements themselves as the keys in the Spark key/value sort. This may be inefficient
* if the comparator only looks at a small fraction of the element to perform the comparison.
*/
private static <T> JavaRDD<T> sort(JavaRDD<T> reads, Comparator<T> comparator, int numReducers) {
public static <T> JavaRDD<T> sortUsingElementsAsKeys(JavaRDD<T> elements, Comparator<T> comparator, int numReducers) {
Utils.nonNull(comparator);
Utils.nonNull(reads);
Utils.nonNull(elements);

// Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount
// of data going through the shuffle.
final JavaPairRDD<T, Void> rddReadPairs = reads.mapToPair(read -> new Tuple2<>(read, (Void) null));
final JavaPairRDD<T, Void> rddReadPairs = elements.mapToPair(read -> new Tuple2<>(read, (Void) null));

final JavaPairRDD<T, Void> readVoidPairs;
if (numReducers > 0) {
@@ -205,17 +169,20 @@
}
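
// Illustrative sketch (editor's addition, not part of this commit): the sort is generic over
// the element type, so any serializable comparator works; Comparator.naturalOrder() is an enum
// singleton and therefore Spark-serializable. "ctx" is a hypothetical JavaSparkContext.
public static JavaRDD<String> exampleSortStrings(final JavaSparkContext ctx) {
    final JavaRDD<String> names = ctx.parallelize(java.util.Arrays.asList("delta", "alpha", "charlie"));
    return sortUsingElementsAsKeys(names, Comparator.naturalOrder(), 0); // 0 = default number of reducers
}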

/**
* Ensure all reads with the same name appear in the same partition.
* Requires that the No shuffle is needed.
* Ensure all reads with the same name appear in the same partition of a queryname sorted RDD.
* This avoids a global shuffle and only transfers the leading elements from each partition, which is fast in most
* cases.
*
* The RDD must be queryname sorted. If there are so many reads with the same name that they span multiple partitions,
* this will throw a {@link GATKException}.
*/
public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition(final SAMFileHeader header, final JavaRDD<GATKRead> reads, final JavaSparkContext ctx) {
Utils.validateArg(ReadUtils.isReadNameGroupedBam(header), () -> "Reads must be queryname grouped or sorted. " +
"Actual sort:" + header.getSortOrder() + " Actual grouping:" +header.getGroupOrder());
int numPartitions = reads.getNumPartitions();
final String firstGroupInBam = reads.first().getName();
final String firstNameInBam = reads.first().getName();
// Find the first group in each partition
List<List<GATKRead>> firstReadNamesInEachPartition = reads
List<List<GATKRead>> firstReadNameGroupInEachPartition = reads
.mapPartitions(it -> { PeekingIterator<GATKRead> current = Iterators.peekingIterator(it);
List<GATKRead> firstGroup = new ArrayList<>(2);
firstGroup.add(current.next());
Expand All @@ -229,7 +196,7 @@ public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition(final

// Checking for pathological cases (read name groups that span more than 2 partitions)
String groupName = null;
for (List<GATKRead> group : firstReadNamesInEachPartition) {
for (List<GATKRead> group : firstReadNameGroupInEachPartition) {
if (group!=null && !group.isEmpty()) {
// If a read spans multiple partitions we expect its name to show up multiple times and we don't expect this to work properly
if (groupName != null && group.get(0).getName().equals(groupName)) {
Expand All @@ -241,17 +208,17 @@ public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition(final
}

// Shift left, so that each partition will be joined with the first read group from the _next_ partition
List<List<GATKRead>> firstReadInNextPartition = new ArrayList<>(firstReadNamesInEachPartition.subList(1, numPartitions));
List<List<GATKRead>> firstReadInNextPartition = new ArrayList<>(firstReadNameGroupInEachPartition.subList(1, numPartitions));
firstReadInNextPartition.add(null); // the last partition does not have any reads to add to it

// Join the reads with the first read from the _next_ partition, then filter out the first and/or last read if not in a pair
// Join the reads with the first read from the _next_ partition, then filter out the first reads in this partition
return reads.zipPartitions(ctx.parallelize(firstReadInNextPartition, numPartitions),
(FlatMapFunction2<Iterator<GATKRead>, Iterator<List<GATKRead>>, GATKRead>) (it1, it2) -> {
PeekingIterator<GATKRead> current = Iterators.peekingIterator(it1);
String firstName = current.peek().getName();
// Make sure we don't remove reads from the first partition
if (!firstGroupInBam.equals(firstName)) {
// skip the first read name group in the _current_ partition if it is the second in a pair since it will be handled in the previous partition
if (!firstNameInBam.equals(firstName)) {
// skip the first read name group in the _current_ partition since it will be handled in the previous partition
while (current.hasNext() && current.peek() != null && current.peek().getName().equals(firstName)) {
current.next();
}
(diff for the remaining changed file did not load)
