diff --git a/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala b/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
index 87df7156bb..5150e5d31a 100644
--- a/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
+++ b/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
@@ -59,6 +59,8 @@ class TransformArgs extends Args4jBase with ADAMSaveAnyArgs with ParquetArgs {
   var useAlignedReadPredicate: Boolean = false
   @Args4jOption(required = false, name = "-sort_reads", usage = "Sort the reads by referenceId and read position")
   var sortReads: Boolean = false
+  @Args4jOption(required = false, name = "-sort_lexicographically", usage = "Sort the reads lexicographically by contig name, instead of by index.")
+  var sortLexicographically: Boolean = false
   @Args4jOption(required = false, name = "-mark_duplicate_reads", usage = "Mark duplicate reads")
   var markDuplicates: Boolean = false
   @Args4jOption(required = false, name = "-recalibrate_base_qualities", usage = "Recalibrate the base quality scores (ILLUMINA only)")
@@ -123,6 +125,7 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
   val stringency = ValidationStringency.valueOf(args.stringency)
 
   def apply(rdd: RDD[AlignmentRecord],
+            sd: SequenceDictionary,
             rgd: RecordGroupDictionary): RDD[AlignmentRecord] = {
 
     var adamRecords = rdd
@@ -207,7 +210,11 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
       }
 
       log.info("Sorting reads")
-      adamRecords = oldRdd.sortReadsByReferencePosition()
+      if (args.sortLexicographically) {
+        adamRecords = oldRdd.sortReadsByReferencePosition()
+      } else {
+        adamRecords = oldRdd.sortReadsByReferencePositionAndIndex(sd)
+      }
 
       if (args.cache) {
         oldRdd.unpersist()
@@ -329,15 +336,19 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
    })
 
     // run our transformation
-    val outputRdd = this.apply(mergedRdd, mergedRgd)
+    val outputRdd = this.apply(mergedRdd, mergedSd, mergedRgd)
 
     // if we are sorting, we must strip the indices from the sequence dictionary
     // and sort the sequence dictionary
     //
     // we must do this because we do a lexicographic sort, not an index-based sort
     val sdFinal = if (args.sortReads) {
-      mergedSd.stripIndices
-        .sorted
+      if (args.sortLexicographically) {
+        mergedSd.stripIndices
+          .sorted
+      } else {
+        mergedSd
+      }
     } else {
       mergedSd
     }
diff --git a/adam-cli/src/test/scala/org/bdgenomics/adam/cli/TransformSuite.scala b/adam-cli/src/test/scala/org/bdgenomics/adam/cli/TransformSuite.scala
index ff66e333b1..a973689b91 100644
--- a/adam-cli/src/test/scala/org/bdgenomics/adam/cli/TransformSuite.scala
+++ b/adam-cli/src/test/scala/org/bdgenomics/adam/cli/TransformSuite.scala
@@ -34,7 +34,7 @@ class TransformSuite extends ADAMFunSuite {
     val inputPath = copyResource("unordered.sam")
     val actualPath = tmpFile("ordered.sam")
     val expectedPath = copyResource("ordered.sam")
-    Transform(Array("-single", "-sort_reads", inputPath, actualPath)).run(sc)
+    Transform(Array("-single", "-sort_reads", "-sort_lexicographically", inputPath, actualPath)).run(sc)
     checkFiles(expectedPath, actualPath)
   }
 
@@ -54,7 +54,7 @@ class TransformSuite extends ADAMFunSuite {
     val actualPath = tmpFile("ordered.sam")
     val expectedPath = copyResource("ordered.sam")
     Transform(Array(inputPath, intermediateAdamPath)).run(sc)
-    Transform(Array("-single", "-sort_reads", intermediateAdamPath, actualPath)).run(sc)
+    Transform(Array("-single", "-sort_reads", "-sort_lexicographically", intermediateAdamPath, actualPath)).run(sc)
     checkFiles(expectedPath, actualPath)
   }
 }
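Note: as the two test updates above show, -sort_reads alone now sorts by sequence dictionary index, and the previous lexicographic behavior requires the new flag. A minimal usage sketch, with hypothetical SAM paths:

    // hypothetical paths; mirrors the updated TransformSuite tests
    Transform(Array("-single", "-sort_reads", "-sort_lexicographically",
      "unordered.sam", "ordered.sam")).run(sc)

Dropping -sort_lexicographically from the same invocation exercises the new index-based sort path instead.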
diff --git a/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala b/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala
index c68ea379ef..75d656ef5c 100644
--- a/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala
+++ b/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala
@@ -18,8 +18,9 @@
 package org.bdgenomics.adam.cli
 
 import org.apache.spark.rdd.RDD
-import org.bdgenomics.adam.util.ADAMFunSuite
+import org.bdgenomics.adam.models.SequenceDictionary
 import org.bdgenomics.adam.rdd.ADAMContext._
+import org.bdgenomics.adam.util.ADAMFunSuite
 import org.bdgenomics.formats.avro.AlignmentRecord
 import org.bdgenomics.utils.cli.Args4j
 
@@ -46,7 +47,7 @@ class ViewSuite extends ADAMFunSuite {
     val rdd = aRdd.rdd
     val rgd = aRdd.recordGroups
 
-    reads = transform.apply(rdd, rgd).collect()
+    reads = transform.apply(rdd, SequenceDictionary.empty, rgd).collect()
     readsCount = reads.size.toInt
   }
 
diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala b/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala
index be8a41bf33..91874c6f67 100644
--- a/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala
+++ b/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala
@@ -69,6 +69,7 @@ object Timers extends Metrics {
 
   // Sort Reads
   val SortReads = timer("Sort Reads")
+  val SortByIndex = timer("Sort Reads By Index")
 
   // File Saving
   val SAMSave = timer("SAM Save")
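Note: the new SortByIndex timer follows the existing Timers pattern, where wrapping a block in timer.time { ... } records its runtime under the given name and returns the block's result; the sort method below wraps its whole body this way. A rough sketch, assuming an RDD[AlignmentRecord] named rdd is in scope:

    import org.bdgenomics.adam.instrumentation.Timers.SortByIndex

    // times the sort and returns the sorted RDD unchanged
    val sortedByStart = SortByIndex.time {
      rdd.sortBy(r => (r.getStart: Long))
    }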
diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala
index c51205d0e1..32fc4da18d 100644
--- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala
+++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala
@@ -44,6 +44,7 @@ import org.bdgenomics.utils.misc.Logging
 import org.seqdoop.hadoop_bam.SAMRecordWritable
 import scala.annotation.tailrec
 import scala.language.implicitConversions
+import scala.math.{ abs, min }
 import scala.reflect.ClassTag
 
 private[rdd] class AlignmentRecordRDDFunctions(val rdd: RDD[AlignmentRecord])
@@ -571,13 +572,37 @@
     // we sort the unmapped reads by read name. We prefix with tildes ("~";
     // ASCII 126) to ensure that the read name is lexicographically "after" the
     // contig names.
-    rdd.keyBy(r => {
+    rdd.sortBy(r => {
       if (r.getReadMapped) {
         ReferencePosition(r)
       } else {
         ReferencePosition(s"~~~${r.getReadName}", 0)
       }
-    }).sortByKey().map(_._2)
+    })
+  }
+
+  def sortReadsByReferencePositionAndIndex(sd: SequenceDictionary): RDD[AlignmentRecord] = SortByIndex.time {
+    log.info("Sorting reads by reference index, using %s.".format(sd))
+
+    import scala.math.Ordering.{ Int => ImplicitIntOrdering, _ }
+
+    // NOTE: In order to keep unmapped reads from swamping a single partition,
+    // we sort the unmapped reads by read name. To do this, we hash the read name
+    // and add the max contig index.
+    val maxContigIndex = sd.records.flatMap(_.referenceIndex).max
+    rdd.sortBy(r => {
+      if (r.getReadMapped) {
+        val sr = sd(r.getContigName)
+        require(sr.isDefined, "Read %s has contig name %s not in dictionary %s.".format(
+          r, r.getContigName, sd))
+        require(sr.get.referenceIndex.isDefined,
+          "Contig %s from sequence dictionary lacks an index.".format(sr.get))
+
+        (sr.get.referenceIndex.get, r.getStart: Long)
+      } else {
+        (min(abs(r.getReadName.hashCode + maxContigIndex), Int.MaxValue), 0L)
+      }
+    })
   }
 
   /**
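Note: a minimal sketch of calling the new method directly, with a hypothetical input path and contig names; every contig in the dictionary must carry a referenceIndex, or the require checks above will fail:

    import org.bdgenomics.adam.models.{ SequenceDictionary, SequenceRecord }
    import org.bdgenomics.adam.rdd.ADAMContext._

    // hypothetical two-contig dictionary: sort order follows referenceIndex
    // (chr1, then chr10), not the lexicographic order of the contig names
    val sd = new SequenceDictionary(Vector(
      SequenceRecord("chr1", 249250621L, referenceIndex = Some(0)),
      SequenceRecord("chr10", 135534747L, referenceIndex = Some(1))))

    val reads = sc.loadAlignments("reads.adam").rdd
    val sorted = reads.sortReadsByReferencePositionAndIndex(sd)

Unmapped reads are keyed past maxContigIndex via the hashed read name, so they land after all mapped reads rather than swamping a single partition.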
diff --git a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala
index 3301504936..d483b07738 100644
--- a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala
+++ b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala
@@ -22,7 +22,11 @@ import java.nio.file.Files
 import htsjdk.samtools.ValidationStringency
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.bdgenomics.adam.models.{ RecordGroupDictionary, SequenceDictionary }
+import org.bdgenomics.adam.models.{
+  RecordGroupDictionary,
+  SequenceDictionary,
+  SequenceRecord
+}
 import org.bdgenomics.adam.rdd.ADAMContext._
 import org.bdgenomics.adam.rdd.TestSaveArgs
 import org.bdgenomics.adam.util.ADAMFunSuite
@@ -30,6 +34,17 @@ import org.bdgenomics.formats.avro._
 import scala.io.Source
 import scala.util.Random
 
+private object SequenceIndexWithReadOrdering extends Ordering[((Int, Long), (AlignmentRecord, Int))] {
+  def compare(a: ((Int, Long), (AlignmentRecord, Int)),
+              b: ((Int, Long), (AlignmentRecord, Int))): Int = {
+    if (a._1._1 == b._1._1) {
+      a._1._2.compareTo(b._1._2)
+    } else {
+      a._1._1.compareTo(b._1._1)
+    }
+  }
+}
+
 class AlignmentRecordRDDFunctionsSuite extends ADAMFunSuite {
 
   sparkTest("sorting reads") {
@@ -39,11 +54,9 @@
       val mapped = random.nextBoolean()
       val builder = AlignmentRecord.newBuilder().setReadMapped(mapped)
       if (mapped) {
-        val contig = Contig.newBuilder
-          .setContigName(random.nextInt(numReadsToCreate / 10).toString)
-          .build
+        val contigName = random.nextInt(numReadsToCreate / 10).toString
         val start = random.nextInt(1000000)
-        builder.setContigName(contig.getContigName).setStart(start).setEnd(start)
+        builder.setContigName(contigName).setStart(start).setEnd(start)
       }
       builder.setReadName((0 until 20).map(i => (random.nextInt(100) + 64)).mkString)
       builder.build()
@@ -59,6 +72,50 @@ class AlignmentRecordRDDFunctionsSuite extends ADAMFunSuite {
     assert(expectedSortedReads === mapped)
   }
 
+  sparkTest("sorting reads by reference index") {
+    val random = new Random("sortingIndices".hashCode)
+    val numReadsToCreate = 1000
+    val reads = for (i <- 0 until numReadsToCreate) yield {
+      val mapped = random.nextBoolean()
+      val builder = AlignmentRecord.newBuilder().setReadMapped(mapped)
+      if (mapped) {
+        val contigName = random.nextInt(numReadsToCreate / 10).toString
+        val start = random.nextInt(1000000)
+        builder.setContigName(contigName).setStart(start).setEnd(start)
+      }
+      builder.setReadName((0 until 20).map(i => (random.nextInt(100) + 64)).mkString)
+      builder.build()
+    }
+    val contigNames = reads.filter(_.getReadMapped).map(_.getContigName).toSet
+    val sd = new SequenceDictionary(contigNames.toSeq
+      .zipWithIndex
+      .map(kv => {
+        val (name, index) = kv
+        SequenceRecord(name, Int.MaxValue, referenceIndex = Some(index))
+      }).toVector)
+
+    val rdd = sc.parallelize(reads)
+    val sortedReads = rdd.sortReadsByReferencePositionAndIndex(sd).collect().zipWithIndex
+    val (mapped, unmapped) = sortedReads.partition(_._1.getReadMapped)
+
+    // Make sure that all the unmapped reads are placed at the end
+    assert(unmapped.forall(p => p._2 > mapped.takeRight(1)(0)._2))
+
+    def toIndex(r: AlignmentRecord): Int = {
+      sd(r.getContigName).get.referenceIndex.get
+    }
+
+    // Make sure that we appropriately sorted the reads
+    import scala.math.Ordering._
+    val expectedSortedReads = mapped.map(kv => {
+      val (r, idx) = kv
+      val start: Long = r.getStart
+      ((toIndex(r), start), (r, idx))
+    }).sortBy(_._1)
+      .map(_._2)
+    assert(expectedSortedReads === mapped)
+  }
+
   sparkTest("characterizeTags counts integer tag values correctly") {
     val tagCounts: Map[String, Long] = Map("XT" -> 10L, "XU" -> 9L, "XV" -> 8L)
     val readItr: Iterable[AlignmentRecord] =