Small bugfixes and cleanups to BQSR #38

Merged · 5 commits · Jan 11, 2014
Diff: AdamMain (edu.berkeley.cs.amplab.adam.cli)
@@ -15,9 +15,10 @@
*/
package edu.berkeley.cs.amplab.adam.cli

import org.apache.spark.Logging
import scala.Some

object AdamMain {
object AdamMain extends Logging {

private val commands = List(Bam2Adam,
Transform,
@@ -48,6 +49,7 @@ object AdamMain {
}

def main(args: Array[String]) {
log.info("ADAM invoked with args: %s".format(argsToString(args)))
Contributor (PR author) commented:
I've changed the text to be clearer that this is logging each invocation, and have included the arguments in a format suitable for copy-and-paste (e.g. to reproduce a run or re-run with slight changes to arguments).

However, I haven't moved it to the else branch to have more complete logs -- is this OK?

Member commented:
Works for me. I like that you added the argsToString feature to copy-paste an ADAM command-line.

if (args.size < 1) {
printCommands()
} else {
@@ -57,4 +59,11 @@ object AdamMain {
}
}
}

// Attempts to format the `args` array into a string in a way
// suitable for copying and pasting back into the shell.
private def argsToString(args: Array[String]): String = {
def escapeArg(s: String) = "\"" + s.replaceAll("\\\"", "\\\\\"") + "\""
args.map(escapeArg).mkString(" ")
}
}
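
A minimal standalone sketch (not part of the PR) of what the logged invocation looks like; it reuses the escaping from argsToString above, and the example arguments are hypothetical:

object ArgsToStringSketch {
  // Mirrors escapeArg/argsToString above: each argument is wrapped in double
  // quotes and embedded quotes are backslash-escaped, so the logged line can
  // be pasted back into a shell to reproduce a run.
  private def escapeArg(s: String) = "\"" + s.replaceAll("\\\"", "\\\\\"") + "\""
  private def argsToString(args: Array[String]): String = args.map(escapeArg).mkString(" ")

  def main(args: Array[String]) {
    val example = Array("transform", "reads.sam", "reads.adam",
      "-recalibrate_base_qualities", "-dbsnp_sites", "dbsnp.vcf")
    println(argsToString(example))
    // prints: "transform" "reads.sam" "reads.adam" "-recalibrate_base_qualities" "-dbsnp_sites" "dbsnp.vcf"
  }
}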
Diff: Transform
@@ -20,6 +20,7 @@ import edu.berkeley.cs.amplab.adam.util.ParquetLogger
import org.kohsuke.args4j.{Argument, Option => Args4jOption}
import edu.berkeley.cs.amplab.adam.avro.ADAMRecord
import edu.berkeley.cs.amplab.adam.rdd.AdamContext._
import edu.berkeley.cs.amplab.adam.models.SnpTable
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.rdd.RDD
import java.io.File
@@ -46,7 +47,7 @@ class TransformArgs extends Args4jBase with ParquetArgs with SparkArgs {
@Args4jOption(required = false, name = "-recalibrate_base_qualities", usage = "Recalibrate the base quality scores (ILLUMINA only)")
var recalibrateBaseQualities: Boolean = false
@Args4jOption(required = false, name = "-dbsnp_sites", usage = "dbsnp sites file")
var dbsnpSitesFile: File = null
var dbsnpSitesFile: String = null
@Args4jOption(required = false, name = "-coalesce", usage = "Set the number of partitions written to the ADAM output directory")
var coalesce: Int = -1
}
@@ -74,7 +75,8 @@ class Transform(protected val args: TransformArgs) extends AdamSparkCommand[Tran

if (args.recalibrateBaseQualities) {
log.info("Recalibrating base qualities")
adamRecords = adamRecords.adamBQSR(args.dbsnpSitesFile)
val dbSNP = loadSnpTable(sc)
adamRecords = adamRecords.adamBQSR(dbSNP)
}

// NOTE: For now, sorting needs to be the last transform
@@ -87,6 +89,17 @@ class Transform(protected val args: TransformArgs) extends AdamSparkCommand[Tran
compressCodec = args.compressionCodec, disableDictionaryEncoding = args.disableDictionary)
}

// FIXME: why doesn't this complain if the file doesn't exist?
def loadSnpTable(sc: SparkContext): SnpTable = {
if(args.dbsnpSitesFile != null) {
log.info("Loading SNP table")
//SnpTable(sc.textFile(args.dbsnpSitesFile))
SnpTable(new File(args.dbsnpSitesFile))
} else {
SnpTable()
}
}

}
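
On the FIXME above: a hedged sketch (not part of the PR) of how loadSnpTable could fail fast when -dbsnp_sites points at a missing path, keeping the surrounding names and behaviour unchanged:

  def loadSnpTableChecked(sc: SparkContext): SnpTable = {
    if (args.dbsnpSitesFile != null) {
      val sitesFile = new File(args.dbsnpSitesFile)
      // new File(...) never touches the filesystem, so check explicitly before reading
      require(sitesFile.exists, "dbSNP sites file not found: " + sitesFile.getPath)
      log.info("Loading SNP table")
      SnpTable(sitesFile)
    } else {
      SnpTable()
    }
  }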


Diff: SnpTable (new file)
@@ -0,0 +1,63 @@
package edu.berkeley.cs.amplab.adam.models

import edu.berkeley.cs.amplab.adam.rdd.AdamContext._
import edu.berkeley.cs.amplab.adam.avro.ADAMRecord
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import scala.collection.immutable._
import scala.collection.mutable
import java.io.File

class SnpTable(private val table: Map[String, Set[Long]]) extends Serializable with Logging {
log.info("SNP table has %s contigs and %s entries".format(table.size, table.values.map(_.size).sum))

def isMaskedAtReadOffset(read: ADAMRecord, offset: Int): Boolean = {
val position = read.readOffsetToReferencePosition(offset)
try {
position.isEmpty || table(read.getReferenceName.toString).contains(position.get)
} catch {
case e: java.util.NoSuchElementException =>
false
}
}
}

object SnpTable {
def apply(): SnpTable = {
new SnpTable(Map[String, Set[Long]]())
}

// `dbSNP` is expected to be a sites-only VCF
def apply(dbSNP: File): SnpTable = {
// parse into tuples of (contig, position)
val lines = scala.io.Source.fromFile(dbSNP).getLines()
val tuples = lines.filter(line => !line.startsWith("#")).map(line => {
val split = line.split("\t")
val contig = split(0)
val pos = split(1).toLong
(contig, pos)
})
// construct map from contig to set of positions
// this is done in-place to reduce overhead
val table = new mutable.HashMap[String, mutable.HashSet[Long]]
tuples.foreach(tup => table.getOrElseUpdate(tup._1, { new mutable.HashSet[Long] }) += tup._2)
// construct SnpTable from immutable copy of `table`
new SnpTable(table.mapValues(_.toSet).toMap)
}

/*
def apply(lines: RDD[String]): SnpTable = {
// parse into tuples of (contig, position)
val tuples = lines.filter(line => !line.startsWith("#")).map(line => {
val split = line.split("\t")
val contig = split(0)
val pos = split(1).toLong
(contig, pos)
})
// construct map from contig to set of positions
val table = tuples.groupByKey.collect.toMap.mapValues(_.toSet)
new SnpTable(table)
}
*/
}
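
A small usage sketch for the new class; the two-record, sites-only VCF below is invented purely for illustration and assumes the SnpTable definition above:

import edu.berkeley.cs.amplab.adam.models.SnpTable
import java.io.{File, PrintWriter}

object SnpTableSketch {
  def main(args: Array[String]) {
    // write a tiny sites-only VCF to a temp file; header lines (starting with '#') are skipped
    val vcf = File.createTempFile("dbsnp_sites", ".vcf")
    val out = new PrintWriter(vcf)
    out.println("##fileformat=VCFv4.1")
    out.println("#CHROM\tPOS\tID\tREF\tALT")
    out.println("chr1\t100\trs1\tA\tG")
    out.println("chr1\t250\trs2\tC\tT")
    out.close()

    // builds Map("chr1" -> Set(100L, 250L)) internally and logs its size;
    // downstream it is broadcast and queried per base via isMaskedAtReadOffset
    val table = SnpTable(vcf)
  }
}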
Diff: AdamRecordRDDFunctions
@@ -22,7 +22,7 @@ import parquet.avro.{AvroParquetOutputFormat, AvroWriteSupport}
import parquet.hadoop.util.ContextUtil
import org.apache.avro.specific.SpecificRecord
import edu.berkeley.cs.amplab.adam.avro.{ADAMPileup, ADAMRecord, ADAMVariant, ADAMGenotype, ADAMVariantDomain}
import edu.berkeley.cs.amplab.adam.models.{SequenceRecord, SequenceDictionary, SingleReadBucket, ReferencePosition, ADAMRod}
import edu.berkeley.cs.amplab.adam.models.{SequenceRecord, SequenceDictionary, SingleReadBucket, SnpTable, ReferencePosition, ADAMRod}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.Logging
@@ -100,18 +100,8 @@ class AdamRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends Serializable with Log
MarkDuplicates(rdd)
}

def adamBQSR(dbSNP: File): RDD[ADAMRecord] = {
val dbsnpMap = scala.io.Source.fromFile(dbSNP).getLines().map(posLine => {
val split = posLine.split("\t")
val contig = split(0)
val pos = split(1).toInt
(contig, pos)
}).foldLeft(Map[String, Set[Int]]())((dbMap, pair) => {
dbMap + (pair._1 -> (dbMap.getOrElse(pair._1, Set[Int]()) + pair._2))
})

val broadcastDbSNP = rdd.context.broadcast(dbsnpMap)

def adamBQSR(dbSNP: SnpTable): RDD[ADAMRecord] = {
val broadcastDbSNP = rdd.context.broadcast(dbSNP)
RecalibrateBaseQualities(rdd, broadcastDbSNP)
}

Diff: RecalibrateBaseQualities
@@ -18,17 +18,21 @@ package edu.berkeley.cs.amplab.adam.rdd
import org.apache.spark.Logging
import org.apache.spark.broadcast.{Broadcast => SparkBroadcast}
import edu.berkeley.cs.amplab.adam.avro.ADAMRecord
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord._
import edu.berkeley.cs.amplab.adam.models.SnpTable
import edu.berkeley.cs.amplab.adam.rdd.recalibration._
import org.apache.spark.rdd.RDD

private[rdd] object RecalibrateBaseQualities extends Serializable with Logging {

def usableRead(read: ADAMRecord): Boolean = {
def usableRead(read: RichADAMRecord): Boolean = {
// todo -- the mismatchingPositions should not merely be a filter, it should result in an exception. These are required for calculating mismatches.
read.getReadMapped && read.getPrimaryAlignment && !read.getDuplicateRead && (read.getMismatchingPositions != null)
}

def apply(rdd: RDD[ADAMRecord], dbsnp: SparkBroadcast[Map[String, Set[Int]]]): RDD[ADAMRecord] = {
def apply(poorRdd: RDD[ADAMRecord], dbsnp: SparkBroadcast[SnpTable]): RDD[ADAMRecord] = {
val rdd = poorRdd.map(new RichADAMRecord(_))
// initialize the covariates
println("Instantiating covariates...")
val qualByRG = new QualByRG(rdd)
@@ -45,7 +49,7 @@ private[rdd] object RecalibrateBaseQualities extends Serializable with Logging {
private[rdd] class RecalibrateBaseQualities(val qualCovar: QualByRG, val covars: List[StandardCovariate]) extends Serializable with Logging {
initLogging()

def computeTable(rdd: RDD[ADAMRecord], dbsnp: SparkBroadcast[Map[String, Set[Int]]]): RecalTable = {
def computeTable(rdd: RDD[RichADAMRecord], dbsnp: SparkBroadcast[SnpTable]): RecalTable = {

def addCovariates(table: RecalTable, covar: ReadCovariates): RecalTable = {
//log.info("Aggregating covarates for read "+covar.read.record.getReadName.toString)
@@ -57,12 +61,12 @@ private[rdd] class RecalibrateBaseQualities(val qualCovar: QualByRG, val covars:
table1 ++ table2
}

rdd.map(r => ReadCovariates(r, qualCovar, covars, dbsnp)).aggregate(new RecalTable)(addCovariates, mergeTables)
rdd.map(r => ReadCovariates(r, qualCovar, covars, dbsnp.value)).aggregate(new RecalTable)(addCovariates, mergeTables)
}

def applyTable(table: RecalTable, rdd: RDD[ADAMRecord], qualCovar: QualByRG, covars: List[StandardCovariate]): RDD[ADAMRecord] = {
def applyTable(table: RecalTable, rdd: RDD[RichADAMRecord], qualCovar: QualByRG, covars: List[StandardCovariate]): RDD[ADAMRecord] = {
table.finalizeTable()
def recalibrate(record: ADAMRecord): ADAMRecord = {
def recalibrate(record: RichADAMRecord): ADAMRecord = {
if (!record.getReadMapped || !record.getPrimaryAlignment || record.getDuplicateRead) {
record // no need to recalibrate these records todo -- enable optional recalibration of all reads
} else {
Diff: ReadCovariates
@@ -16,17 +16,19 @@
package edu.berkeley.cs.amplab.adam.rdd.recalibration

import edu.berkeley.cs.amplab.adam.rdd.AdamContext._
import edu.berkeley.cs.amplab.adam.avro.ADAMRecord
import org.apache.spark.broadcast.{Broadcast => SparkBroadcast}
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord._
import edu.berkeley.cs.amplab.adam.models.SnpTable

object ReadCovariates {
def apply(rec: ADAMRecord, qualRG: QualByRG, covars: List[StandardCovariate], dbsnp: SparkBroadcast[Map[String, Set[Int]]] = null): ReadCovariates = {
def apply(rec: RichADAMRecord, qualRG: QualByRG, covars: List[StandardCovariate],
dbsnp: SnpTable = SnpTable()): ReadCovariates = {
new ReadCovariates(rec, qualRG, covars, dbsnp)
}
}

class ReadCovariates(val read: ADAMRecord, qualByRG: QualByRG, covars: List[StandardCovariate],
val dbsnp: SparkBroadcast[Map[String, Set[Int]]] = null) extends Iterator[BaseCovariates] with Serializable {
class ReadCovariates(val read: RichADAMRecord, qualByRG: QualByRG, covars: List[StandardCovariate],
val dbSNP: SnpTable) extends Iterator[BaseCovariates] with Serializable {

val startOffset = read.qualityScores.takeWhile(_ <= 2).size
val endOffset = read.qualityScores.size - read.qualityScores.reverseIterator.takeWhile(_ <= 2).size
@@ -38,15 +40,15 @@ class ReadCovariates(val read: ADAMRecord, qualByRG: QualByRG, covars: List[Stan
override def hasNext: Boolean = iter_position < endOffset

override def next(): BaseCovariates = {
val idx = (iter_position - startOffset).toInt
val position = read.getPosition(idx)
val isMasked = dbsnp == null || position.isEmpty ||
dbsnp.value(read.getReferenceName.toString).contains(position.get.toInt) ||
read.isMismatchBase(idx).isEmpty
val isMisMatch = read.isMismatchBase(idx).getOrElse(false) // getOrElse because reads without an MD tag can appear during *application* of recal table
val offset = (iter_position - startOffset).toInt
val mismatch = read.isMismatchAtReadOffset(offset)
// FIXME: why does empty mismatch mean it should be masked?
val isMasked = dbSNP.isMaskedAtReadOffset(read, offset) || mismatch.isEmpty
// getOrElse because reads without an MD tag can appear during *application* of recal table
val isMismatch = mismatch.getOrElse(false)
iter_position += 1
new BaseCovariates(qualCovar(idx), requestedCovars.map(v => v(idx)).toArray,
read.qualityScores(idx), isMisMatch, isMasked)
new BaseCovariates(qualCovar(offset), requestedCovars.map(v => v(offset)).toArray,
read.qualityScores(offset), isMismatch, isMasked)
}

}
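
For reference, a standalone illustration of how the startOffset and endOffset computed near the top of this class trim the low-quality (score <= 2) bases at both ends of a read; the quality values are made up:

val qualityScores = Array(2, 2, 10, 20, 2)
val startOffset = qualityScores.takeWhile(_ <= 2).size                                      // 2
val endOffset   = qualityScores.size - qualityScores.reverseIterator.takeWhile(_ <= 2).size // 4
// the iterator then visits offsets 2 and 3 only (the Q10 and Q20 bases)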
Diff: RecalUtil
@@ -16,6 +16,8 @@
package edu.berkeley.cs.amplab.adam.rdd.recalibration

import edu.berkeley.cs.amplab.adam.avro.ADAMRecord
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord._

object RecalUtil extends Serializable {

@@ -29,7 +31,7 @@ object RecalUtil extends Serializable {

def errorProbToQual(d: Double): Byte = (-10 * math.log10(d)).toInt.toByte

def recalibrate(read: ADAMRecord, qualByRG: QualByRG, covars: List[StandardCovariate], table: RecalTable): ADAMRecord = {
def recalibrate(read: RichADAMRecord, qualByRG: QualByRG, covars: List[StandardCovariate], table: RecalTable): ADAMRecord = {
// get the covariates
val readCovariates = ReadCovariates(read, qualByRG, covars)
val toQual = errorProbToQual _
@@ -42,4 +44,4 @@
builder.build()
}

}
}
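
For context on errorProbToQual in the hunk above, the Phred-style conversion maps an error probability to a quality byte; two worked values (note that toInt truncates toward zero rather than rounding):

val q13: Byte = (-10 * math.log10(0.05)).toInt.toByte   // 13 (from 13.01...)
val q33: Byte = (-10 * math.log10(0.0005)).toInt.toByte // 33 (from 33.01...)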
Diff: QualByRG / StandardCovariate
@@ -17,15 +17,17 @@ package edu.berkeley.cs.amplab.adam.rdd.recalibration

import edu.berkeley.cs.amplab.adam.rdd.AdamContext._
import edu.berkeley.cs.amplab.adam.avro.ADAMRecord
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord
import edu.berkeley.cs.amplab.adam.rich.RichADAMRecord._
import org.apache.spark.rdd.RDD

// this class is required, not just standard. Baked in to recalibration.
class QualByRG(rdd: RDD[ADAMRecord]) extends Serializable {
class QualByRG(rdd: RDD[RichADAMRecord]) extends Serializable {
// need to get the unique read groups todo --- this is surprisingly slow
//val readGroups = rdd.map(_.getRecordGroupId.toString).distinct().collect().sorted.zipWithIndex.toMap
var readGroups = Map[String, Int]()

def apply(read: ADAMRecord, start: Int, end: Int): Array[Int] = {
def apply(read: RichADAMRecord, start: Int, end: Int): Array[Int] = {
if (!readGroups.contains(read.getRecordGroupId.asInstanceOf[String])) {
readGroups += (read.getRecordGroupId.asInstanceOf[String] -> readGroups.size)
}
@@ -37,47 +39,47 @@ class QualByRG(rdd: RDD[ADAMRecord]) extends Serializable {
}

trait StandardCovariate extends Serializable {
def apply(read: ADAMRecord, start: Int, end: Int): Array[Int] // get the covariate for all the bases of the read
def apply(read: RichADAMRecord, start: Int, end: Int): Array[Int] // get the covariate for all the bases of the read
}

case class DiscreteCycle(args: RDD[ADAMRecord]) extends StandardCovariate {
case class DiscreteCycle(args: RDD[RichADAMRecord]) extends StandardCovariate {
// this is a special-case of the GATK's Cycle covariate for discrete technologies.
// Not to be used for 454 or ion torrent (which are flow cycles)
def apply(read: ADAMRecord, startOffset: Int, endOffset: Int): Array[Int] = {
def apply(read: RichADAMRecord, startOffset: Int, endOffset: Int): Array[Int] = {
var cycles: Array[Int] = if (read.getReadNegativeStrand) Range(read.getSequence.toString.size, 0, -1).toArray
else Range(1, 1 + read.getSequence.toString.size, 1).toArray
cycles = if (read.getReadPaired && read.getSecondOfPair) cycles.map(-_) else cycles
cycles.slice(startOffset, endOffset)
}
}
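
A standalone illustration of the cycle numbering above, for a hypothetical 4-base read:

val readLength = 4
val forwardCycles  = Range(1, 1 + readLength, 1).toArray // Array(1, 2, 3, 4)
val negativeCycles = Range(readLength, 0, -1).toArray    // Array(4, 3, 2, 1) for a negative-strand read
val secondOfPair   = negativeCycles.map(-_)              // Array(-4, -3, -2, -1) when the read is the second of a pair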

case class BaseContext(records: RDD[ADAMRecord], size: Int) extends StandardCovariate {
case class BaseContext(records: RDD[RichADAMRecord], size: Int) extends StandardCovariate {
def this(_s: Int) = this(null, _s)

def this(_r: RDD[ADAMRecord]) = this(_r, 2)
def this(_r: RDD[RichADAMRecord]) = this(_r, 2)

val BASES = Array('A'.toByte, 'C'.toByte, 'G'.toByte, 'T'.toByte)
val COMPL = Array('T'.toByte, 'G'.toByte, 'C'.toByte, 'A'.toByte)
val N_BASE = 'N'.toByte
val COMPL_MP = (BASES zip COMPL toMap) + (N_BASE -> N_BASE)

def apply(read: ADAMRecord, startOffset: Int, endOffset: Int): Array[Int] = {
def apply(read: RichADAMRecord, startOffset: Int, endOffset: Int): Array[Int] = {
// the context of a covariate is the previous @size bases, though "previous" depends on
// how the read was aligned (negative strand is reverse-complemented).
if (read.getReadNegativeStrand) reverseContext(read, startOffset, endOffset) else forwardContext(read, startOffset, endOffset)
}

// note: the last base is dropped from the construction of contexts because it is not
// present in any context - just as the first base cannot have a context assigned to it.
def forwardContext(rec: ADAMRecord, st: Int, end: Int): Array[Int] = {
def forwardContext(rec: RichADAMRecord, st: Int, end: Int): Array[Int] = {
getContext(rec.getSequence.asInstanceOf[String].toCharArray.map(_.toByte).slice(st, end))
}

def simpleReverseComplement(bases: Array[Byte]): Array[Byte] = {
bases.map(b => COMPL_MP(b)).reverseIterator.toArray
}

def reverseContext(rec: ADAMRecord, st: Int, end: Int): Array[Int] = {
def reverseContext(rec: RichADAMRecord, st: Int, end: Int): Array[Int] = {
// first reverse-complement the sequence
val baseSeq = simpleReverseComplement(rec.getSequence.asInstanceOf[String].toCharArray.map(_.toByte))
getContext(baseSeq.slice(baseSeq.size - end, baseSeq.size - st))