Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored SequenceDictionary #162

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class ListDict(protected val args: ListDictArgs) extends AdamSparkCommand[ListDi
ParquetLogger.hadoopLoggerLevel(Level.SEVERE)

val dict = sc.adamDictionaryLoad[ADAMRecord](args.inputPath)
dict.records.toList.sortBy(_.id).foreach {

dict.recordsIn.sortBy(_.id).foreach {
rec: SequenceRecord =>
println("%d\t%s\t%d".format(rec.id, rec.name, rec.length))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.math.Ordering.Implicits._
* SequenceDictionary contains the (bijective) map between Ints (the referenceId) and Strings (the referenceName)
* from the header of a BAM file, or the combined result of multiple such SequenceDictionaries.
*/
class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializable {
class SequenceDictionary(val recordsIn: Array[SequenceRecord]) extends Serializable {

// Intermediate value used to ensure that no referenceName or referenceId is listed twice with a different
// referenceId or referenceName (respectively). Notice the "toSet", which means it's okay to pass an Iterable
Expand All @@ -48,13 +48,13 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
}

// Maps referenceId -> SequenceRecord
private val recordIndices: mutable.Map[Int, SequenceRecord] =
private lazy val recordIndices: mutable.Map[Int, SequenceRecord] =
mutable.Map(recordsIn.map {
rec => (rec.id, rec)
}.toSeq: _*)

// Maps referenceName -> SequenceRecord
private val recordNames: mutable.Map[CharSequence, SequenceRecord] =
private lazy val recordNames: mutable.Map[CharSequence, SequenceRecord] =
mutable.Map(recordsIn.map {
// Call toString explicitly, since otherwise we were picking up an Avro-specific Utf8 value here,
// which was making the containsRefName method below fail in a hard-to-understand way.
Expand All @@ -71,12 +71,18 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
/**
* Returns the sequence record associated with a specific contig name.
*
* @throws AssertionError Throws assertion error if sequence corresponding to contig name
* is not found.
*
* @param name Name to search for.
* @return SequenceRecord associated with this record.
*/
def apply(name: CharSequence): SequenceRecord = {
// must explicitly call toString - see note at recordNames creation RE: Avro & Utf8
recordNames(name.toString)
val rec = recordsIn.find(kv => kv.name.toString == name.toString)

assert(rec.isDefined, "Could not find key " + name + " in dictionary.")
rec.get
}

/**
Expand All @@ -87,7 +93,7 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
*/
def containsRefName(name : CharSequence) : Boolean = {
// must explicitly call toString - see note at recordNames creation RE: Avro & Utf8
recordNames.contains(name.toString)
!recordsIn.forall(kv => kv.name.toString != name.toString)
}

/**
Expand Down Expand Up @@ -178,7 +184,6 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab

/**
* See the note to mapTo, above.

* The results of this remap and mapTo should be to produce a "compatible" dictionary,
* i.e. for all d1 and d2,
*
Expand All @@ -194,47 +199,34 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab
def remapIndex(i: Int): Int =
if (idTransform.contains(i)) idTransform(i) else i

SequenceDictionary(idNamePairs.map {
new SequenceDictionary(idNamePairs.map {
case (id, name) =>
recordIndices(id).withReferenceId(remapIndex(id))
}.toSeq: _*)
}.toArray)
}

def records: Seq[SequenceRecord] = recordIndices.values.toSeq

def +(rec: SequenceRecord): SequenceDictionary =
new SequenceDictionary(recordsIn ++ List(rec))
def records: Set[SequenceRecord] = recordIndices.values.toSet

def +=(rec: SequenceRecord): SequenceDictionary = {

recordIndices.put(rec.id, rec)
recordNames.put(rec.name, rec)
this
/**
 * Concatenates two arrays of sequence records, skipping every record of the second array
 * that is already contained in the first.
 *
 * All records of a1 are kept, in their original order, followed by the surviving records
 * of a2. Duplicates inside a2 itself are not removed.
 *
 * @param a1 Records that are always kept.
 * @param a2 Records that are appended unless a1 already contains them.
 * @return The merged array.
 */
private[models] def cleanAndMerge(a1: Array[SequenceRecord],
                                  a2: Array[SequenceRecord]): Array[SequenceRecord] =
  a1 ++ a2.filterNot(candidate => a1.contains(candidate))

def ++(dict: SequenceDictionary): SequenceDictionary =
new SequenceDictionary(recordsIn ++ dict.records)

def ++(recs: Seq[SequenceRecord]): SequenceDictionary =
recs.foldRight(this)((rec, dict) => dict + rec)

def ++=(recs: Seq[SequenceRecord]): SequenceDictionary = {
recs.foreach {
rec => this += rec
def +(record: SequenceRecord): SequenceDictionary = {
if (recordsIn.contains(record)) {
new SequenceDictionary(recordsIn)
} else {
new SequenceDictionary(recordsIn :+ record)
}
this
}

def ++=(dict: SequenceDictionary): SequenceDictionary = {
dict.recordIndices.keys.foreach {
idx => {
val newrec = dict.recordIndices(idx)
recordIndices.put(newrec.id, newrec)
recordNames.put(newrec.name, newrec)
}
}
this
}
/**
 * Builds a new dictionary from this dictionary's records plus those records of the given
 * dictionary that this one does not already contain. Neither operand is modified.
 *
 * @param dict Dictionary whose records should be merged in.
 * @return A new dictionary built from the merged records.
 */
def ++(dict: SequenceDictionary): SequenceDictionary = {
  val mergedRecords = cleanAndMerge(recordsIn, dict.recordsIn)
  new SequenceDictionary(mergedRecords)
}

/**
 * Builds a new dictionary from this dictionary's records plus those of the given records
 * that this dictionary does not already contain. This dictionary is not modified.
 *
 * @param recs Records that should be merged in.
 * @return A new dictionary built from the merged records.
 */
def ++(recs: Array[SequenceRecord]): SequenceDictionary = {
  val mergedRecords = cleanAndMerge(recordsIn, recs)
  new SequenceDictionary(mergedRecords)
}

/**
* Tests whether two dictionaries are compatible, where "compatible" means that
Expand Down Expand Up @@ -287,7 +279,7 @@ class SequenceDictionary(recordsIn: Iterable[SequenceRecord]) extends Serializab

object SequenceDictionary {

def apply(recordsIn: SequenceRecord*) = new SequenceDictionary(recordsIn)
def apply(recordsIn: SequenceRecord*) = new SequenceDictionary(recordsIn.toArray)

/**
* Extracts a SAM sequence dictionary from a SAM file header and returns an
Expand Down Expand Up @@ -445,7 +437,7 @@ object SequenceRecord {
* @param rec The ADAMRecord from which to extract the SequenceRecord entries
* @return a set of all SequenceRecord entries derivable from this record.
*/
def fromADAMRecord(rec: ADAMRecord): Seq[SequenceRecord] = {
def fromADAMRecord(rec: ADAMRecord): Set[SequenceRecord] = {

assert(rec != null, "ADAMRecord was null")

Expand All @@ -456,29 +448,29 @@ object SequenceRecord {

val left =
if (rec.getReadMapped)
List(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
Set(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
else
List()
Set()

val right =
if (rec.getMateMapped)
List(SequenceRecord(rec.getMateReferenceId, rec.getMateReference, rec.getMateReferenceLength, rec.getMateReferenceUrl))
Set(SequenceRecord(rec.getMateReferenceId, rec.getMateReference, rec.getMateReferenceLength, rec.getMateReferenceUrl))
else
List()
Set()

left ++ right

} else {
List()
Set()
}

} else {

if (rec.getReadMapped) {
List(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
Set(SequenceRecord(rec.getReferenceId, rec.getReferenceName, rec.getReferenceLength, rec.getReferenceUrl))
} else {
// If the read isn't mapped, then ignore the fields altogether.
List()
Set()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,76 @@ class AdamRDDFunctions[T <% SpecificRecord : Manifest](rdd: RDD[T]) extends Seri

}

class AdamRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends Serializable with Logging {
/**
 * A class that provides functions to recover a sequence dictionary from a generic RDD of records.
 *
 * Subclasses only describe how a single element maps to sequence records; the aggregation
 * over the whole RDD is implemented here.
 *
 * @tparam T Type contained in this RDD.
 * @param rdd RDD over which aggregation is supported.
 */
abstract class AdamSequenceDictionaryRDDAggregator[T](rdd: RDD[T]) extends Serializable with Logging {
  /**
   * For a single RDD element, returns 0+ sequence record elements.
   *
   * @param elem Element from which to extract sequence records.
   * @return A set of sequence records.
   */
  def getSequenceRecordsFromElement (elem: T): scala.collection.Set[SequenceRecord]

  /**
   * Aggregates together a sequence dictionary from the different individual reference sequences
   * used in this dataset.
   *
   * Each partition is first collapsed into one dictionary of its distinct records, and the
   * per-partition dictionaries are then merged with `++`.
   *
   * @return A sequence dictionary describing the reference contigs in this dataset. An RDD
   *         without any partitions yields an empty dictionary.
   */
  def adamGetSequenceDictionary (): SequenceDictionary = {
    // Prepends to the accumulator every record of the element that has not been seen yet.
    def mergeRecords(l: List[SequenceRecord], rec: T): List[SequenceRecord] =
      getSequenceRecordsFromElement(rec).foldLeft(l) { (seen: List[SequenceRecord], r: SequenceRecord) =>
        if (seen.contains(r)) seen else r :: seen
      }

    // Collapses one partition into a dictionary of its distinct records.
    def foldIterator(iter: Iterator[T]): SequenceDictionary = {
      val distinctRecords = iter.foldLeft(List[SequenceRecord]())(mergeRecords)
      new SequenceDictionary(distinctRecords.toArray)
    }

    // fold (rather than reduce) with an empty dictionary as the zero value: reduce fails on
    // an empty collection, which is what mapPartitions produces for an RDD with no partitions.
    // Merging with an empty dictionary leaves the other operand's records unchanged.
    rdd.mapPartitions(iter => Iterator(foldIterator(iter)), true)
      .fold(SequenceDictionary())(_ ++ _)
  }

}

/**
 * Sequence dictionary aggregator for RDDs whose element type is defined in Avro and carries
 * the reference identification fields directly inside the record.
 *
 * @note Avro classes that have specific constraints around sequence dictionary contents should
 * not use this class. Examples include ADAMRecords and ADAMNucleotideContigs.
 *
 * @tparam T A type defined in Avro that contains the reference identification fields.
 * @param rdd RDD over which aggregation is supported.
 */
class AdamSpecificRecordSequenceDictionaryRDDAggregator[T <% SpecificRecord : Manifest](rdd: RDD[T])
  extends AdamSequenceDictionaryRDDAggregator[T](rdd) {

  /**
   * Derives the sequence record of a single element.
   *
   * @param elem Element from which to extract the sequence record.
   * @return A set holding the one sequence record derived from the element.
   */
  def getSequenceRecordsFromElement (elem: T): Set[SequenceRecord] = {
    val record = SequenceRecord.fromSpecificRecord(elem)
    Set(record)
  }
}

class AdamRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends AdamSequenceDictionaryRDDAggregator[ADAMRecord](rdd) {

/**
 * For a single read, returns the sequence records derivable from it.
 * Delegates to SequenceRecord.fromADAMRecord.
 *
 * @param elem Read from which to extract sequence records.
 * @return A set of sequence records.
 */
def getSequenceRecordsFromElement (elem: ADAMRecord): scala.collection.Set[SequenceRecord] =
  SequenceRecord.fromADAMRecord(elem)

def adamSortReadsByReferencePosition(): RDD[ADAMRecord] = {
log.info("Sorting reads by reference position")

Expand Down Expand Up @@ -104,11 +173,6 @@ class AdamRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends Serializable with Log
}).sortByKey().map(p => p._2)
}

def sequenceDictionary(): SequenceDictionary =
rdd.distinct().aggregate(SequenceDictionary())(
(dict: SequenceDictionary, rec: ADAMRecord) => dict ++ SequenceRecord.fromADAMRecord(rec),
(dict1: SequenceDictionary, dict2: SequenceDictionary) => dict1 ++ dict2)

def adamMarkDuplicates(): RDD[ADAMRecord] = {
MarkDuplicates(rdd)
}
Expand Down Expand Up @@ -324,7 +388,7 @@ class AdamRodRDDFunctions(rdd: RDD[ADAMRod]) extends Serializable with Logging {
}
}

class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFragment]) extends Serializable with Logging {
class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFragment]) extends AdamSequenceDictionaryRDDAggregator[ADAMNucleotideContigFragment](rdd) {

/**
* Rewrites the contig IDs of a FASTA reference set to match the contig IDs present in a
Expand Down Expand Up @@ -365,19 +429,6 @@ class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFrag
rdd.flatMap(c => remapContig(c, bcastDict.value))
}

/**
* From this set of contigs, returns a sequence dictionary.
*
* @see AdamRecordRDDFunctions#sequenceDictionary
*
* @return Sequence dictionary representing this reference.
*/
def adamGetSequenceDictionary (): SequenceDictionary =
rdd.map(ctg => SequenceRecord.fromADAMContigFragment(ctg))
.distinct()
.aggregate(SequenceDictionary())((dict: SequenceDictionary, rec: SequenceRecord) => dict ++ Seq(rec),
(dict1: SequenceDictionary, dict2: SequenceDictionary) => dict1 ++ dict2)

/**
* From a set of contigs, returns the base sequence that corresponds to a region of the reference.
*
Expand Down Expand Up @@ -435,4 +486,10 @@ class AdamNucleotideContigFragmentRDDFunctions(rdd: RDD[ADAMNucleotideContigFrag
}
}
}

/**
 * For a single contig fragment, returns the sequence record derived from it.
 *
 * @param elem Contig fragment from which to extract the sequence record.
 * @return A set holding the one sequence record derived from the fragment.
 */
def getSequenceRecordsFromElement (elem: ADAMNucleotideContigFragment): Set[SequenceRecord] = {
  val contigRecord = SequenceRecord.fromADAMContigFragment(elem)
  Set(contigRecord)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ object GenomicRegionPartitioner {
def apply(N: Int, lengths: Map[Int, Long]) = new GenomicRegionPartitioner(N, lengths)

def extractLengthMap(seqDict: SequenceDictionary): Map[Int, Long] =
Map(seqDict.records.map(rec => (rec.id, rec.length)): _*)
Map(seqDict.records.toSeq.map(rec => (rec.id, rec.length)): _*)
}


Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,25 @@ import edu.berkeley.cs.amplab.adam.avro.{ADAMGenotypeType, ADAMGenotype, ADAMDat
import edu.berkeley.cs.amplab.adam.models.{ADAMVariantContext,
SequenceDictionary,
SequenceRecord}
import edu.berkeley.cs.amplab.adam.rdd.AdamSequenceDictionaryRDDAggregator
import edu.berkeley.cs.amplab.adam.rich.RichADAMVariant
import edu.berkeley.cs.amplab.adam.rich.RichADAMGenotype._
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._

class ADAMVariantContextRDDFunctions(rdd: RDD[ADAMVariantContext]) extends Serializable with Logging {
class ADAMVariantContextRDDFunctions(rdd: RDD[ADAMVariantContext]) extends AdamSequenceDictionaryRDDAggregator[ADAMVariantContext](rdd) {

/**
 * For a single variant context, returns the sequence records derived from the variants of
 * its genotypes. Genotypes that yield equal records collapse into one set entry.
 *
 * @param elem Element from which to extract sequence records.
 * @return A set of sequence records.
 */
def getSequenceRecordsFromElement (elem: ADAMVariantContext): scala.collection.Set[SequenceRecord] = {
  val recordPerGenotype =
    for (genotype <- elem.genotypes) yield SequenceRecord.fromSpecificRecord(genotype.getVariant)

  recordPerGenotype.toSet
}

/**
* Left outer join database variant annotations
*
Expand All @@ -39,11 +51,6 @@ class ADAMVariantContextRDDFunctions(rdd: RDD[ADAMVariantContext]) extends Seria

}

def adamGetSequenceDictionary(): SequenceDictionary =
rdd.map(_.genotypes).distinct().aggregate(SequenceDictionary())(
(dict: SequenceDictionary, rec: Seq[ADAMGenotype]) => dict ++ rec.map((genotype : ADAMGenotype) => SequenceRecord.fromSpecificRecord(genotype.getVariant)),
(dict1: SequenceDictionary, dict2: SequenceDictionary) => dict1 ++ dict2)

def adamGetCallsetSamples(): List[String] = {
rdd.flatMap(c => c.genotypes.map(_.getSampleId).distinct)
.distinct
Expand Down
Loading