From e162e543f1049908b33299c6caf631cbf05edf70 Mon Sep 17 00:00:00 2001
From: Frank Austin Nothaft
Date: Thu, 27 Feb 2014 12:30:21 -0800
Subject: [PATCH] Refactored SequenceDictionary aggregation to eliminate code
 duplication between RDD extensions and to improve speed. Additionally,
 refactored the SequenceDictionary class to reduce GC overhead and to make
 the class immutable.

---
 .../cs/amplab/adam/cli/ListDict.scala         |  3 +-
 .../adam/models/SequenceDictionary.scala      | 84 ++++++++--------
 .../cs/amplab/adam/rdd/AdamRDDFunctions.scala | 97 +++++++++++++++----
 .../adam/rdd/GenomicRegionPartitioner.scala   |  2 +-
 .../variation/ADAMVariationRDDFunctions.scala | 19 ++--
 .../adam/models/SequenceDictionarySuite.scala | 22 +----
 6 files changed, 133 insertions(+), 94 deletions(-)

diff --git a/adam-cli/src/main/scala/edu/berkeley/cs/amplab/adam/cli/ListDict.scala b/adam-cli/src/main/scala/edu/berkeley/cs/amplab/adam/cli/ListDict.scala
index 893c9a9860..0f3e4f14b9 100644
--- a/adam-cli/src/main/scala/edu/berkeley/cs/amplab/adam/cli/ListDict.scala
+++ b/adam-cli/src/main/scala/edu/berkeley/cs/amplab/adam/cli/ListDict.scala
@@ -46,7 +46,8 @@ class ListDict(protected val args: ListDictArgs) extends AdamSparkCommand[ListDi
     ParquetLogger.hadoopLoggerLevel(Level.SEVERE)
 
     val dict = sc.adamDictionaryLoad[ADAMRecord](args.inputPath)
-    dict.records.toList.sortBy(_.id).foreach {
+
+    dict.recordsIn.sortBy(_.id).foreach {
       rec: SequenceRecord =>
         println("%d\t%s\t%d".format(rec.id, rec.name, rec.length))
     }
diff --git a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionary.scala b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionary.scala
index 0964e358ef..85e810493b 100644
--- a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionary.scala
+++ b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionary.scala
@@ -27,7 +27,7 @@ import scala.math.Ordering.Implicits._
  * SequenceDictionary contains the (bijective) map between Ints (the referenceId) and Strings (the referenceName)
  * from the header of a BAM file, or the combined result of multiple such SequenceDictionaries.
  */
-class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializable {
+class SequenceDictionary(val recordsIn: Array[SequenceRecord]) extends Serializable {
 
   // Intermediate value used to ensure that no referenceName or referenceId is listed twice with a different
   // referenceId or referenceName (respectively). Notice the "toSet", which means it's okay to pass an Iterable
@@ -48,13 +48,13 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
   }
 
   // Maps referenceId -> SequenceRecord
-  private val recordIndices: mutable.Map[Int, SequenceRecord] =
+  private lazy val recordIndices: mutable.Map[Int, SequenceRecord] =
     mutable.Map(recordsIn.map {
       rec => (rec.id, rec)
     }.toSeq: _*)
 
   // Maps referenceName -> SequenceRecord
-  private val recordNames: mutable.Map[CharSequence, SequenceRecord] =
+  private lazy val recordNames: mutable.Map[CharSequence, SequenceRecord] =
     mutable.Map(recordsIn.map {
       // Call toString explicitly, since otherwise we were picking up an Avro-specific Utf8 value here,
       // which was making the containsRefName method below fail in a hard-to-understand way.
@@ -71,12 +71,18 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
 
   /**
    * Returns the sequence record associated with a specific contig name.
   *
+  * @throws AssertionError If the sequence corresponding to the contig name
+  *                        is not found.
+  *
   * @param name Name to search for.
   * @return The SequenceRecord associated with this name.
   */
  def apply(name: CharSequence): SequenceRecord = {
    // must explicitly call toString - see note at recordNames creation RE: Avro & Utf8
-    recordNames(name.toString)
+    val rec = recordsIn.find(kv => kv.name.toString == name.toString)
+
+    assert(rec.isDefined, "Could not find key " + name + " in dictionary.")
+    rec.get
  }
 
  /**
@@ -87,7 +93,7 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
   */
  def containsRefName(name : CharSequence) : Boolean = {
    // must explicitly call toString - see note at recordNames creation RE: Avro & Utf8
-    recordNames.contains(name.toString)
+    recordsIn.exists(kv => kv.name.toString == name.toString)
  }
 
  /**
@@ -178,7 +184,6 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
 
  /**
   * See the note to mapTo, above.
-  *
   * The results of this remap and mapTo should be to produce a "compatible" dictionary,
   * i.e. for all d1 and d2,
   *
@@ -194,47 +199,34 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
 
    def remapIndex(i: Int): Int =
      if (idTransform.contains(i)) idTransform(i) else i
 
-    SequenceDictionary(idNamePairs.map {
+    new SequenceDictionary(idNamePairs.map {
      case (id, name) =>
        recordIndices(id).withReferenceId(remapIndex(id))
-    }.toSeq: _*)
+    }.toArray)
  }
 
-  def records: Seq[SequenceRecord] = recordIndices.values.toSeq
-
-  def +(rec: SequenceRecord): SequenceDictionary =
-    new SequenceDictionary(recordsIn ++ List(rec))
+  def records: Set[SequenceRecord] = recordIndices.values.toSet
 
-  def +=(rec: SequenceRecord): SequenceDictionary = {
-
-    recordIndices.put(rec.id, rec)
-    recordNames.put(rec.name, rec)
-    this
+  private[models] def cleanAndMerge(a1: Array[SequenceRecord],
+                                    a2: Array[SequenceRecord]): Array[SequenceRecord] = {
+    val a2filt = a2.filter(k => !a1.contains(k))
+
+    a1 ++ a2filt
  }
 
-  def ++(dict: SequenceDictionary): SequenceDictionary =
-    new SequenceDictionary(recordsIn ++ dict.records)
-
-  def ++(recs: Seq[SequenceRecord]): SequenceDictionary =
-    recs.foldRight(this)((rec, dict) => dict + rec)
-
-  def ++=(recs: Seq[SequenceRecord]): SequenceDictionary = {
-    recs.foreach {
-      rec => this += rec
+  def +(record: SequenceRecord): SequenceDictionary = {
+    if (recordsIn.contains(record)) {
+      new SequenceDictionary(recordsIn)
+    } else {
+      new SequenceDictionary(recordsIn :+ record)
    }
-    this
  }
 
-  def ++=(dict: SequenceDictionary): SequenceDictionary = {
-    dict.recordIndices.keys.foreach {
-      idx => {
-        val newrec = dict.recordIndices(idx)
-        recordIndices.put(newrec.id, newrec)
-        recordNames.put(newrec.name, newrec)
-      }
-    }
-    this
-  }
+  def ++(dict: SequenceDictionary): SequenceDictionary =
+    new SequenceDictionary(cleanAndMerge(recordsIn, dict.recordsIn))
+
+  def ++(recs: Array[SequenceRecord]): SequenceDictionary =
+    new SequenceDictionary(cleanAndMerge(recordsIn, recs))
 
  /**
   * Tests whether two dictionaries are compatible, where "compatible" means that
@@ -287,7 +279,7 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
 
 object SequenceDictionary {
 
-  def apply(recordsIn: SequenceRecord*) = new SequenceDictionary(recordsIn)
+  def apply(recordsIn: SequenceRecord*) = new SequenceDictionary(recordsIn.toArray)
 
  /**
   * Extracts a SAM sequence dictionary from a SAM file header and returns an
@@ -445,7 +437,7 @@ object SequenceRecord {
   *
   * @param rec The ADAMRecord from which to extract the SequenceRecord entries
   * @return a set of all SequenceRecord entries derivable from this record.
   */
-  def fromADAMRecord(rec: ADAMRecord): Seq[SequenceRecord] = {
+  def fromADAMRecord(rec: ADAMRecord): Set[SequenceRecord] = {
 
    assert(rec != null, "ADAMRecord was null")
 
@@ -456,29 +448,29 @@ object SequenceRecord {
 
      val left =
        if (rec.getReadMapped)
-          List(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
+          Set(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
        else
-          List()
+          Set()
 
      val right =
        if (rec.getMateMapped)
-          List(SequenceRecord(rec.getMateReferenceId, rec.getMateReference, rec.getMateReferenceLength, rec.getMateReferenceUrl))
+          Set(SequenceRecord(rec.getMateReferenceId, rec.getMateReference, rec.getMateReferenceLength, rec.getMateReferenceUrl))
        else
-          List()
+          Set()
 
      left ++ right
    } else {
-      List()
+      Set()
    }
  } else {
    if (rec.getReadMapped) {
-      List(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
+      Set(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
    } else {
      // If the read isn't mapped, then ignore the fields altogether.
-      List()
+      Set()
    }
  }
 }
diff --git a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/AdamRDDFunctions.scala b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/AdamRDDFunctions.scala
index 9787e5a34c..562bdda585 100644
--- a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/AdamRDDFunctions.scala
+++ b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/AdamRDDFunctions.scala
@@ -70,7 +70,76 @@
 
 }
 
-class AdamRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends Serializable with Logging {
+/**
+ * A class that provides functions to recover a sequence dictionary from a generic RDD of records.
+ *
+ * @tparam T Type contained in this RDD.
+ * @param rdd RDD over which aggregation is supported.
+ */
+abstract class AdamSequenceDictionaryRDDAggregator[T](rdd: RDD[T]) extends Serializable with Logging {
+  /**
+   * For a single RDD element, returns 0+ sequence record elements.
+   *
+   * @param elem Element from which to extract sequence records.
+   * @return A set of sequence records.
+   */
+  def getSequenceRecordsFromElement (elem: T): scala.collection.Set[SequenceRecord]
+
+  /**
+   * Aggregates together a sequence dictionary from the different individual reference sequences
+   * used in this dataset.
+   *
+   * @return A sequence dictionary describing the reference contigs in this dataset.
+   */
+  def adamGetSequenceDictionary (): SequenceDictionary = {
+    def mergeRecords(l: List[SequenceRecord], rec: T): List[SequenceRecord] = {
+      val recs = getSequenceRecordsFromElement(rec)
+
+      recs.foldLeft(l)((li: List[SequenceRecord], r: SequenceRecord) => {
+        if (!li.contains(r)) {
+          r :: li
+        } else {
+          li
+        }
+      })
+    }
+
+    def foldIterator(iter: Iterator[T]): SequenceDictionary = {
+      val recs = iter.foldLeft(List[SequenceRecord]())(mergeRecords)
+      new SequenceDictionary(recs.toArray)
+    }
+
+    rdd.mapPartitions(iter => Iterator(foldIterator(iter)), true)
+      .reduce(_ ++ _)
+  }
+
+}
+
+/**
+ * A class that provides functions to recover a sequence dictionary from a generic RDD of records
+ * that are defined in Avro. This class assumes that the reference identification fields are
+ * defined inside of the given type.
+ *
+ * @note Avro classes that have specific constraints around sequence dictionary contents should
+ *       not use this class. Examples include ADAMRecords and ADAMNucleotideContigs.
+ *
+ * @tparam T A type defined in Avro that contains the reference identification fields.
+ * @param rdd RDD over which aggregation is supported.
+ */
+class AdamSpecificRecordSequenceDictionaryRDDAggregator[T <% SpecificRecord : Manifest](rdd: RDD[T])
+  extends AdamSequenceDictionaryRDDAggregator[T](rdd) {
+
+  def getSequenceRecordsFromElement (elem: T): Set[SequenceRecord] = {
+    Set(SequenceRecord.fromSpecificRecord(elem))
+  }
+}
+
+class AdamRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends AdamSequenceDictionaryRDDAggregator[ADAMRecord](rdd) {
+
+  def getSequenceRecordsFromElement (elem: ADAMRecord): scala.collection.Set[SequenceRecord] = {
+    SequenceRecord.fromADAMRecord(elem)
+  }
+
  def adamSortReadsByReferencePosition(): RDD[ADAMRecord] = {
    log.info("Sorting reads by reference position")
@@ -104,11 +173,6 @@ class AdamRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends Serializable with Log
    }).sortByKey().map(p => p._2)
  }
 
-  def sequenceDictionary(): SequenceDictionary =
-    rdd.distinct().aggregate(SequenceDictionary())(
-      (dict: SequenceDictionary, rec: ADAMRecord) => dict ++ SequenceRecord.fromADAMRecord(rec),
-      (dict1: SequenceDictionary, dict2: SequenceDictionary) => dict1 ++ dict2)
-
  def adamMarkDuplicates(): RDD[ADAMRecord] = {
    MarkDuplicates(rdd)
  }
@@ -324,7 +388,7 @@ class AdamRodRDDFunctions(rdd: RDD[ADAMRod]) extends Serializable with Logging {
  }
 }
 
-class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFragment]) extends Serializable with Logging {
+class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFragment]) extends AdamSequenceDictionaryRDDAggregator[ADAMNucleotideContigFragment](rdd) {
 
  /**
   * Rewrites the contig IDs of a FASTA reference set to match the contig IDs present in a
@@ -365,19 +429,6 @@ class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFrag
    rdd.flatMap(c => remapContig(c, bcastDict.value))
  }
 
-  /**
-   * From this set of contigs, returns a sequence dictionary.
-   *
-   * @see AdamRecordRDDFunctions#sequenceDictionary
-   *
-   * @return Sequence dictionary representing this reference.
-   */
-  def adamGetSequenceDictionary (): SequenceDictionary =
-    rdd.map(ctg => SequenceRecord.fromADAMContigFragment(ctg))
-      .distinct()
-      .aggregate(SequenceDictionary())((dict: SequenceDictionary, rec: SequenceRecord) => dict ++ Seq(rec),
-      (dict1: SequenceDictionary, dict2: SequenceDictionary) => dict1 ++ dict2)
-
  /**
   * From a set of contigs, returns the base sequence that corresponds to a region of the reference.
   *
@@ -435,4 +486,10 @@ class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFrag
      }
    }
  }
+
+  def getSequenceRecordsFromElement (elem: ADAMNucleotideContigFragment): Set[SequenceRecord] = {
+    // each contig fragment derives from a single reference contig
+    Set(SequenceRecord.fromADAMContigFragment(elem))
+  }
+
 }
diff --git a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/GenomicRegionPartitioner.scala b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/GenomicRegionPartitioner.scala
index 378d87c055..0b2b1e255e 100644
--- a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/GenomicRegionPartitioner.scala
+++ b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/GenomicRegionPartitioner.scala
@@ -100,7 +100,7 @@ object GenomicRegionPartitioner {
  def apply(N: Int, lengths: Map[Int, Long]) = new GenomicRegionPartitioner(N, lengths)
 
  def extractLengthMap(seqDict: SequenceDictionary): Map[Int, Long] =
-    Map(seqDict.records.map(rec => (rec.id, rec.length)): _*)
+    Map(seqDict.records.toSeq.map(rec => (rec.id, rec.length)): _*)
 }
diff --git a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/variation/ADAMVariationRDDFunctions.scala b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/variation/ADAMVariationRDDFunctions.scala
index 31171fa4e5..5e2caa8cb0 100644
--- a/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/variation/ADAMVariationRDDFunctions.scala
+++ b/adam-core/src/main/scala/edu/berkeley/cs/amplab/adam/rdd/variation/ADAMVariationRDDFunctions.scala
@@ -20,13 +20,25 @@ import edu.berkeley.cs.amplab.adam.avro.{ADAMGenotypeType, ADAMGenotype, ADAMDat
 import edu.berkeley.cs.amplab.adam.models.{ADAMVariantContext, SequenceDictionary, SequenceRecord}
+import edu.berkeley.cs.amplab.adam.rdd.AdamSequenceDictionaryRDDAggregator
 import edu.berkeley.cs.amplab.adam.rich.RichADAMVariant
 import edu.berkeley.cs.amplab.adam.rich.RichADAMGenotype._
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 
-class ADAMVariantContextRDDFunctions(rdd: RDD[ADAMVariantContext]) extends Serializable with Logging {
+class ADAMVariantContextRDDFunctions(rdd: RDD[ADAMVariantContext]) extends AdamSequenceDictionaryRDDAggregator[ADAMVariantContext](rdd) {
+
+  /**
+   * For a single variant context, returns the associated sequence records.
+   *
+   * @param elem Element from which to extract sequence records.
+   * @return A set of sequence records.
+ */ + def getSequenceRecordsFromElement (elem: ADAMVariantContext): scala.collection.Set[SequenceRecord] = { + elem.genotypes.map(gt => SequenceRecord.fromSpecificRecord(gt.getVariant)).toSet + } + /** * Left outer join database variant annotations * @@ -39,11 +51,6 @@ class ADAMVariantContextRDDFunctions(rdd: RDD[ADAMVariantContext]) extends Seria } - def adamGetSequenceDictionary(): SequenceDictionary = - rdd.map(_.genotypes).distinct().aggregate(SequenceDictionary())( - (dict: SequenceDictionary, rec: Seq[ADAMGenotype]) => dict ++ rec.map((genotype : ADAMGenotype) => SequenceRecord.fromSpecificRecord(genotype.getVariant)), - (dict1: SequenceDictionary, dict2: SequenceDictionary) => dict1 ++ dict2) - def adamGetCallsetSamples(): List[String] = { rdd.flatMap(c => c.genotypes.map(_.getSampleId).distinct) .distinct diff --git a/adam-core/src/test/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionarySuite.scala b/adam-core/src/test/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionarySuite.scala index c0243def10..9b56ee4f45 100644 --- a/adam-core/src/test/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionarySuite.scala +++ b/adam-core/src/test/scala/edu/berkeley/cs/amplab/adam/models/SequenceDictionarySuite.scala @@ -160,7 +160,7 @@ class SequenceDictionarySuite extends FunSuite { assert(map(3) === 1) } - test("the additions + and += work correctly") { + test("the addition + works correctly") { val s1 = SequenceDictionary() val s2 = SequenceDictionary(record(1, "foo")) val s3 = SequenceDictionary(record(1, "foo"), record(2, "bar")) @@ -168,18 +168,9 @@ class SequenceDictionarySuite extends FunSuite { assert(s1 + record(1, "foo") === s2) assert(s2 + record(1, "foo") === s2) assert(s2 + record(2, "bar") === s3) - - s1 += record(1, "foo") - assert(s1 === s2) - - s1 += record(1, "foo") - assert(s1 === s2) - - s1 += record(2, "bar") - assert(s1 === s3) } - test("the append operations ++ and ++= work correctly") { + test("the append operation ++ works correctly") { val s1 = SequenceDictionary() val s2a = SequenceDictionary(record(1, "foo")) val s2b = SequenceDictionary(record(2, "bar")) @@ -189,15 +180,6 @@ class SequenceDictionarySuite extends FunSuite { assert(s1 ++ s2a === s2a) assert(s1 ++ s2b === s2b) assert(s2a ++ s2b === s3) - - s1 ++= s2a - assert(s1 === s2a) - - s1 ++= s2b - assert(s1 === s3) - - s1 ++= s3 - assert(s1 === s3) } test("containsRefName works correctly") {