Merge pull request #592 from fnothaft/remove-mapping-context
[ADAM-513] Remove ReferenceMappable trait.
laserson committed Mar 2, 2015
2 parents 4027323 + 4c700b8 commit 28b34b3
Showing 10 changed files with 76 additions and 211 deletions.
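The substance of the change shows up in every hunk below: the region joins, partitioners, and region helpers no longer resolve genomic coordinates through an implicit ReferenceMapping type class; callers instead pass RDDs that are already keyed by ReferenceRegion. A minimal caller-side sketch of the new convention (the helper name keyReadsByRegion and the flatMap used to drop unmapped reads are illustrative, not part of this commit):

import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.models.ReferenceRegion
import org.bdgenomics.formats.avro.AlignmentRecord

// ReferenceRegion(read) yields an Option (the hunks below call .get on it),
// so unmapped reads are dropped and each mapped read is paired with its region.
def keyReadsByRegion(reads: RDD[AlignmentRecord]): RDD[(ReferenceRegion, AlignmentRecord)] =
  reads.flatMap(read => ReferenceRegion(read).map(region => (region, read)).toSeq)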
@@ -28,7 +28,6 @@ import org.bdgenomics.adam.projections.Projection
import org.bdgenomics.adam.projections.AlignmentRecordField._
import org.bdgenomics.adam.rdd.ADAMContext._
import org.bdgenomics.adam.rdd.BroadcastRegionJoin
import org.bdgenomics.adam.rich.ReferenceMappingContext._
import org.bdgenomics.formats.avro.AlignmentRecord
import scala.io._

@@ -90,8 +89,8 @@ class CalculateDepth(protected val args: CalculateDepthArgs) extends ADAMSparkCo
val variantNames = vcf.collect().toMap

val joinedRDD: RDD[(ReferenceRegion, AlignmentRecord)] =
if (args.cartesian) BroadcastRegionJoin.cartesianFilter(variantPositions, mappedRDD)
else BroadcastRegionJoin.partitionAndJoin(sc, variantPositions, mappedRDD)
if (args.cartesian) BroadcastRegionJoin.cartesianFilter(variantPositions.keyBy(v => v), mappedRDD.keyBy(ReferenceRegion(_).get))
else BroadcastRegionJoin.partitionAndJoin(sc, variantPositions.keyBy(v => v), mappedRDD.keyBy(ReferenceRegion(_).get))

val depths: RDD[(ReferenceRegion, Int)] =
joinedRDD.map { case (region, record) => (region, 1) }.reduceByKey(_ + _).sortByKey()
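In this caller the left-hand values are themselves ReferenceRegions, so keying by identity satisfies the new (ReferenceRegion, value) contract. A sketch of the two keyed shapes fed into the join (types spelled out for illustration; the .get mirrors the committed code above and assumes every record in mappedRDD is mapped):

// Left side: the values are the regions, so each region keys itself.
val variantsByRegion: RDD[(ReferenceRegion, ReferenceRegion)] = variantPositions.keyBy(v => v)
// Right side: each read is keyed by the region it aligns to.
val readsByRegion: RDD[(ReferenceRegion, AlignmentRecord)] = mappedRDD.keyBy(ReferenceRegion(_).get)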

This file was deleted.

@@ -17,7 +17,7 @@
*/
package org.bdgenomics.adam.rdd

import org.bdgenomics.adam.models.{ SequenceDictionary, ReferenceMapping, ReferenceRegion }
import org.bdgenomics.adam.models.{ SequenceDictionary, ReferenceRegion }
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import scala.Predef._
@@ -53,12 +53,8 @@ object BroadcastRegionJoin {
* operation. The result is the region-join.
*
* @param sc A SparkContext for the cluster that will perform the join
* @param baseRDD The 'left' side of the join, a set of values which correspond (through an implicit
* ReferenceMapping) to regions on the genome.
* @param joinedRDD The 'right' side of the join, a set of values which correspond (through an implicit
* ReferenceMapping) to regions on the genome
* @param tMapping implicit reference mapping for baseRDD regions
* @param uMapping implicit reference mapping for joinedRDD regions
* @param baseRDD The 'left' side of the join
* @param joinedRDD The 'right' side of the join
* @param tManifest implicit type of baseRDD
* @param uManifest implicit type of joinedRDD
* @tparam T type of baseRDD
@@ -67,11 +63,9 @@
* corresponding to x overlaps the region corresponding to y.
*/
def partitionAndJoin[T, U](sc: SparkContext,
baseRDD: RDD[T],
joinedRDD: RDD[U])(implicit tMapping: ReferenceMapping[T],
uMapping: ReferenceMapping[U],
tManifest: ClassTag[T],
uManifest: ClassTag[U]): RDD[(T, U)] = {
baseRDD: RDD[(ReferenceRegion, T)],
joinedRDD: RDD[(ReferenceRegion, U)])(implicit tManifest: ClassTag[T],
uManifest: ClassTag[U]): RDD[(T, U)] = {

/**
* Original Join Design:
@@ -100,7 +94,7 @@
// and collect them.
val collectedLeft: Seq[(String, Iterable[ReferenceRegion])] =
baseRDD
.map(t => (tMapping.getReferenceName(t), tMapping.getReferenceRegion(t))) // RDD[(String,ReferenceRegion)]
.map(kv => (kv._1.referenceName, kv._1)) // RDD[(String,ReferenceRegion)]
.groupBy(_._1) // RDD[(String,Seq[(String,ReferenceRegion)])]
.map(t => (t._1, t._2.map(_._2))) // RDD[(String,Seq[ReferenceRegion])]
.collect() // Iterable[(String,Seq[ReferenceRegion])]
@@ -115,26 +109,26 @@
val regions = sc.broadcast(multiNonOverlapping)

// each element of the left-side RDD should have exactly one partition.
val smallerKeyed: RDD[(ReferenceRegion, T)] =
baseRDD.keyBy(t => regions.value.regionsFor(t).head)
val smallerKeyed: RDD[(ReferenceRegion, (ReferenceRegion, T))] =
baseRDD.map(t => (regions.value.regionsFor(t).head, t))

// each element of the right-side RDD may have 0, 1, or more than 1 corresponding partition.
val largerKeyed: RDD[(ReferenceRegion, U)] =
val largerKeyed: RDD[(ReferenceRegion, (ReferenceRegion, U))] =
joinedRDD.filter(regions.value.filter(_))
.flatMap(t => regions.value.regionsFor(t).map((r: ReferenceRegion) => (r, t)))

// this is (essentially) performing a cartesian product within each partition...
val joined: RDD[(ReferenceRegion, (T, U))] =
val joined: RDD[(ReferenceRegion, ((ReferenceRegion, T), (ReferenceRegion, U)))] =
smallerKeyed.join(largerKeyed)

// ... so we need to filter the final pairs to make sure they're overlapping.
val filtered: RDD[(ReferenceRegion, (T, U))] = joined.filter({
case (rr: ReferenceRegion, (t: T, u: U)) =>
tMapping.getReferenceRegion(t).overlaps(uMapping.getReferenceRegion(u))
val filtered: RDD[(ReferenceRegion, ((ReferenceRegion, T), (ReferenceRegion, U)))] = joined.filter(kv => {
val (rr: ReferenceRegion, (t: (ReferenceRegion, T), u: (ReferenceRegion, U))) = kv
t._1.overlaps(u._1)
})

// finally, erase the partition key and return the result.
filtered.map(rrtu => rrtu._2)
filtered.map(rrtu => (rrtu._2._1._2, rrtu._2._2._2))
}

/**
@@ -146,15 +140,13 @@
* realistic sized sets.
*
*/
def cartesianFilter[T, U](baseRDD: RDD[T],
joinedRDD: RDD[U])(implicit tMapping: ReferenceMapping[T],
uMapping: ReferenceMapping[U],
tManifest: ClassTag[T],
uManifest: ClassTag[U]): RDD[(T, U)] = {
def cartesianFilter[T, U](baseRDD: RDD[(ReferenceRegion, T)],
joinedRDD: RDD[(ReferenceRegion, U)])(implicit tManifest: ClassTag[T],
uManifest: ClassTag[U]): RDD[(T, U)] = {
baseRDD.cartesian(joinedRDD).filter({
case (t: T, u: U) =>
tMapping.getReferenceRegion(t).overlaps(uMapping.getReferenceRegion(u))
})
case (t: (ReferenceRegion, T), u: (ReferenceRegion, U)) =>
t._1.overlaps(u._1)
}).map(p => (p._1._2, p._2._2))
}
}
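With the implicit ReferenceMapping parameters gone, both entry points share one shape: two pre-keyed RDDs in, value pairs out. A hedged call-site sketch (reads, features, and sc are assumed to be in scope; ReferenceRegion(feature) follows the pattern used in the feature hunks further down):

val readsByRegion: RDD[(ReferenceRegion, AlignmentRecord)] =
  reads.keyBy(r => ReferenceRegion(r).get)
val featuresByRegion: RDD[(ReferenceRegion, Feature)] =
  features.keyBy(f => ReferenceRegion(f))

// The left ("base") side is collected and broadcast, so the smaller set goes there;
// cartesianFilter takes the same two keyed RDDs, minus the SparkContext.
val joined: RDD[(Feature, AlignmentRecord)] =
  BroadcastRegionJoin.partitionAndJoin(sc, featuresByRegion, readsByRegion)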

@@ -285,8 +277,8 @@ class NonoverlappingRegions(regions: Iterable[ReferenceRegion]) extends Serializ
* @return An Iterable[ReferenceRegion], where each element of the Iterable is a nonoverlapping-region
* defined by 1 or more input-set regions.
*/
def regionsFor[U](regionable: U)(implicit mapping: ReferenceMapping[U]): Iterable[ReferenceRegion] =
findOverlappingRegions(mapping.getReferenceRegion(regionable))
def regionsFor[U](regionable: (ReferenceRegion, U)): Iterable[ReferenceRegion] =
findOverlappingRegions(regionable._1)

/**
* A quick filter, to find out if we even need to examine a particular input value for keying by
@@ -299,9 +291,8 @@ class NonoverlappingRegions(regions: Iterable[ReferenceRegion]) extends Serializ
* @return a boolean -- the input value should only participate in the regionJoin if the return value
* here is 'true'.
*/
def hasRegionsFor[U](regionable: U)(implicit mapping: ReferenceMapping[U]): Boolean = {
val region = mapping.getReferenceRegion(regionable)
!(region.end <= endpoints.head || region.start >= endpoints.last)
def hasRegionsFor[U](regionable: (ReferenceRegion, U)): Boolean = {
!(regionable._1.end <= endpoints.head || regionable._1.start >= endpoints.last)
}

override def toString: String =
@@ -310,8 +301,8 @@ class NonoverlappingRegions(regions: Iterable[ReferenceRegion]) extends Serializ

object NonoverlappingRegions {

def apply[T](values: Seq[T])(implicit refMapping: ReferenceMapping[T]) =
new NonoverlappingRegions(values.map(value => refMapping.getReferenceRegion(value)))
def apply[T](values: Seq[(ReferenceRegion, T)]) =
new NonoverlappingRegions(values.map(_._1))

def alternating[T](seq: Seq[T], includeFirst: Boolean): Seq[T] = {
val inds = if (includeFirst) { 0 until seq.size } else { 1 until seq.size + 1 }
@@ -340,23 +331,17 @@ class MultiContigNonoverlappingRegions(
val regionMap: Map[String, NonoverlappingRegions] =
Map(regions.map(r => (r._1, new NonoverlappingRegions(r._2))): _*)

def regionsFor[U](regionable: U)(implicit mapping: ReferenceMapping[U]): Iterable[ReferenceRegion] =
regionMap.get(mapping.getReferenceName(regionable)) match {
case None => Seq()
case Some(nr) => nr.regionsFor(regionable)
}
def regionsFor[U](regionable: (ReferenceRegion, U)): Iterable[ReferenceRegion] =
regionMap.get(regionable._1.referenceName).fold(Iterable[ReferenceRegion]())(_.regionsFor(regionable))

def filter[U](value: U)(implicit mapping: ReferenceMapping[U]): Boolean =
regionMap.get(mapping.getReferenceName(value)) match {
case None => false
case Some(nr) => nr.hasRegionsFor(value)
}
def filter[U](value: (ReferenceRegion, U)): Boolean =
regionMap.get(value._1.referenceName).fold(false)(_.hasRegionsFor(value))
}

object MultiContigNonoverlappingRegions {
def apply[T](values: Seq[T])(implicit mapping: ReferenceMapping[T]): MultiContigNonoverlappingRegions = {
def apply[T](values: Seq[(ReferenceRegion, T)]): MultiContigNonoverlappingRegions = {
new MultiContigNonoverlappingRegions(
values.map(v => (mapping.getReferenceName(v), mapping.getReferenceRegion(v)))
values.map(kv => (kv._1.referenceName, kv._1))
.groupBy(t => t._1)
.map(t => (t._1, t._2.map(k => k._2)))
.toSeq)
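The helper classes follow the same shift: they now consume (ReferenceRegion, T) pairs directly instead of going through a mapping. A sketch of how partitionAndJoin above drives them (the broadcast step is elided and the variable names are illustrative):

// Build the non-overlapping partition structure from the collected left side...
val collectedPairs: Seq[(ReferenceRegion, Feature)] = featuresByRegion.collect().toSeq
val nonOverlapping = MultiContigNonoverlappingRegions(collectedPairs)

// ...then pre-filter the right side and fan each element out to every partition it touches.
val keyedRight = readsByRegion
  .filter(kv => nonOverlapping.filter(kv))
  .flatMap(kv => nonOverlapping.regionsFor(kv).map(r => (r, kv)))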
@@ -17,7 +17,7 @@
*/
package org.bdgenomics.adam.rdd

import org.bdgenomics.adam.models.{ ReferenceRegion, ReferenceMapping, ReferencePosition, SequenceDictionary }
import org.bdgenomics.adam.models.{ ReferenceRegion, ReferencePosition, SequenceDictionary }
import org.apache.spark.{ Logging, Partitioner }
import scala.math._

@@ -99,16 +99,12 @@ object GenomicPositionPartitioner {
Map(seqDict.records.toSeq.map(rec => (rec.name.toString, rec.length)): _*)
}

case class GenomicRegionPartitioner[T: ReferenceMapping](partitionSize: Long, seqLengths: Map[String, Long], start: Boolean = true) extends Partitioner with Logging {
case class GenomicRegionPartitioner(partitionSize: Long, seqLengths: Map[String, Long], start: Boolean = true) extends Partitioner with Logging {
private val names: Seq[String] = seqLengths.keys.toSeq.sortWith(_ < _)
private val lengths: Seq[Long] = names.map(seqLengths(_))
private val parts: Seq[Int] = lengths.map(v => round(ceil(v.toDouble / partitionSize)).toInt)
private val cumulParts: Map[String, Int] = Map(names.zip(parts.scan(0)(_ + _)): _*)

private def extractReferenceRegion(k: T)(implicit tMapping: ReferenceMapping[T]): ReferenceRegion = {
tMapping.getReferenceRegion(k)
}

private def computePartition(refReg: ReferenceRegion): Int = {
val pos = if (start) refReg.start else (refReg.end - 1)
(cumulParts(refReg.referenceName) + pos / partitionSize).toInt
@@ -118,13 +114,13 @@ case class GenomicRegionPartitioner[T: ReferenceMapping](partitionSize: Long, se

override def getPartition(key: Any): Int = {
key match {
case mappable: T => computePartition(extractReferenceRegion(mappable))
case _ => throw new IllegalArgumentException("Only ReferenceMappable values can be partitioned by GenomicRegionPartitioner")
case region: ReferenceRegion => computePartition(region)
case _ => throw new IllegalArgumentException("Only ReferenceMappable values can be partitioned by GenomicRegionPartitioner")
}
}
}

object GenomicRegionPartitioner {
def apply[T: ReferenceMapping](partitionSize: Long, seqDict: SequenceDictionary): GenomicRegionPartitioner[T] =
def apply(partitionSize: Long, seqDict: SequenceDictionary): GenomicRegionPartitioner =
GenomicRegionPartitioner(partitionSize, GenomicPositionPartitioner.extractLengthMap(seqDict))
}
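Dropping the type parameter means the partitioner no longer needs a ReferenceMapping in scope: anything keyed by ReferenceRegion can be partitioned. A hedged usage sketch (seqDict and readsByRegion are assumed to be in scope, and the 1 Mbp bin size is arbitrary):

import org.apache.spark.SparkContext._ // pair-RDD functions such as partitionBy

val partitioner = GenomicRegionPartitioner(1000000L, seqDict)

// getPartition now matches on the ReferenceRegion key directly.
val partitioned: RDD[(ReferenceRegion, AlignmentRecord)] =
  readsByRegion.partitionBy(partitioner)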
@@ -21,8 +21,7 @@ package org.bdgenomics.adam.rdd
import org.apache.spark.{ Logging, Partitioner, SparkContext }
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.models.{ SequenceDictionary, ReferenceRegion, ReferenceMapping }

import org.bdgenomics.adam.models.{ SequenceDictionary, ReferenceRegion }
import scala.collection.mutable.ListBuffer
import scala.math._
import scala.reflect.ClassTag
@@ -40,15 +39,11 @@ object ShuffleRegionJoin {
* the object in each bin. Finally, each bin independently performs a chromsweep sort-merge join.
*
* @param sc A SparkContext for the cluster that will perform the join
* @param leftRDD The 'left' side of the join, a set of values which correspond (through an implicit
* ReferenceMapping) to regions on the genome.
* @param rightRDD The 'right' side of the join, a set of values which correspond (through an implicit
* ReferenceMapping) to regions on the genome
* @param leftRDD The 'left' side of the join
* @param rightRDD The 'right' side of the join
* @param seqDict A SequenceDictionary -- every region corresponding to either the leftRDD or rightRDD
* values must be mapped to a chromosome with an entry in this dictionary.
* @param partitionSize The size of the genome bin in nucleotides. Controls the parallelism of the join.
* @param tMapping implicit reference mapping for leftRDD regions
* @param uMapping implicit reference mapping for rightRDD regions
* @param tManifest implicit type of leftRDD
* @param uManifest implicit type of rightRDD
* @tparam T type of leftRDD
@@ -57,12 +52,10 @@ object ShuffleRegionJoin {
* corresponding to x overlaps the region corresponding to y.
*/
def partitionAndJoin[T, U](sc: SparkContext,
leftRDD: RDD[T],
rightRDD: RDD[U],
leftRDD: RDD[(ReferenceRegion, T)],
rightRDD: RDD[(ReferenceRegion, U)],
seqDict: SequenceDictionary,
partitionSize: Long)(implicit tMapping: ReferenceMapping[T],
uMapping: ReferenceMapping[U],
tManifest: ClassTag[T],
partitionSize: Long)(implicit tManifest: ClassTag[T],
uManifest: ClassTag[U]): RDD[(T, U)] = {
// Create the set of bins across the genome for parallel processing
val seqLengths = Map(seqDict.records.toSeq.map(rec => (rec.name.toString, rec.length)): _*)
@@ -71,15 +64,15 @@
// Key each RDD element to its corresponding bin
// Elements may be replicated if they overlap multiple bins
val keyedLeft: RDD[((ReferenceRegion, Int), T)] =
leftRDD.flatMap(x => {
val region = tMapping.getReferenceRegion(x)
leftRDD.flatMap(kv => {
val (region, x) = kv
val lo = bins.value.getStartBin(region)
val hi = bins.value.getEndBin(region)
(lo to hi).map(i => ((region, i), x))
})
val keyedRight: RDD[((ReferenceRegion, Int), U)] =
rightRDD.flatMap(y => {
val region = uMapping.getReferenceRegion(y)
rightRDD.flatMap(kv => {
val (region, y) = kv
val lo = bins.value.getStartBin(region)
val hi = bins.value.getEndBin(region)
(lo to hi).map(i => ((region, i), y))
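ShuffleRegionJoin gets the same treatment: the caller supplies pre-keyed RDDs plus a sequence dictionary and bin size, and the region now comes straight off the key when elements are fanned out to bins. A hedged call-site sketch (the 1 Mbp partition size and variable names are illustrative):

val joined: RDD[(Feature, AlignmentRecord)] =
  ShuffleRegionJoin.partitionAndJoin(
    sc,
    featuresByRegion, // RDD[(ReferenceRegion, Feature)]
    readsByRegion,    // RDD[(ReferenceRegion, AlignmentRecord)]
    seqDict,
    partitionSize = 1000000L)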
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.bdgenomics.adam.models._
import org.bdgenomics.adam.rich.ReferenceMappingContext.FeatureReferenceMapping
import org.bdgenomics.formats.avro.{ Strand, Feature }
import scala.collection.JavaConversions._

@@ -71,23 +70,23 @@ class GeneFeatureRDDFunctions(featureRDD: RDD[Feature]) extends Serializable wit
case ("exon", ftr: Feature) =>
val ids: Seq[String] = ftr.getParentIds.map(_.toString)
ids.map(transcriptId => (transcriptId,
Exon(ftr.getFeatureId.toString, transcriptId, strand(ftr.getStrand), FeatureReferenceMapping.getReferenceRegion(ftr))))
Exon(ftr.getFeatureId.toString, transcriptId, strand(ftr.getStrand), ReferenceRegion(ftr))))
}.groupByKey()

val cdsByTranscript: RDD[(String, Iterable[CDS])] =
typePartitioned.filter(_._1 == "CDS").flatMap {
case ("CDS", ftr: Feature) =>
val ids: Seq[String] = ftr.getParentIds.map(_.toString)
ids.map(transcriptId => (transcriptId,
CDS(transcriptId, strand(ftr.getStrand), FeatureReferenceMapping.getReferenceRegion(ftr))))
CDS(transcriptId, strand(ftr.getStrand), ReferenceRegion(ftr))))
}.groupByKey()

val utrsByTranscript: RDD[(String, Iterable[UTR])] =
typePartitioned.filter(_._1 == "UTR").flatMap {
case ("UTR", ftr: Feature) =>
val ids: Seq[String] = ftr.getParentIds.map(_.toString)
ids.map(transcriptId => (transcriptId,
UTR(transcriptId, strand(ftr.getStrand), FeatureReferenceMapping.getReferenceRegion(ftr))))
UTR(transcriptId, strand(ftr.getStrand), ReferenceRegion(ftr))))
}.groupByKey()

// Step #3
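For features, the per-type FeatureReferenceMapping helper is replaced by a direct ReferenceRegion constructor, so each exon, CDS, or UTR region is built in one step. A minimal sketch (ftr, transcriptId, and strand are assumed to be in scope, as in the hunk above):

// ReferenceRegion(ftr) replaces FeatureReferenceMapping.getReferenceRegion(ftr).
val region: ReferenceRegion = ReferenceRegion(ftr)
val exon = Exon(ftr.getFeatureId.toString, transcriptId, strand(ftr.getStrand), region)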