From 9962d5ffccd7778f924f81f790914ba5c2638268 Mon Sep 17 00:00:00 2001
From: Frank Austin Nothaft
Date: Wed, 18 Nov 2015 11:49:47 -0800
Subject: [PATCH] [ADAM-883] Add caching to Transform pipeline.

The Transform pipeline in the CLI has several stages (e.g., sort,
indel realignment, BQSR) that trigger recomputation. If you are running
a single stage off of local storage/HDFS/Tachyon, this is OK. However,
if you are running multiple stages, or you are loading data from a
remote store like S3, the repeated recomputation can lead to serious
performance degradation. To address this, I've added the appropriate
caching statements around each of these stages. Additionally, I've
added a hook so that the user can specify the storage level to use for
caching.

Resolves #883.
---
 .../org/bdgenomics/adam/cli/Transform.scala | 45 +++++++++++++++++--
 1 file changed, 41 insertions(+), 4 deletions(-)

diff --git a/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala b/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
index 1407c05f31..1be8e96f38 100644
--- a/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
+++ b/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
@@ -19,8 +19,9 @@ package org.bdgenomics.adam.cli
 
 import htsjdk.samtools.ValidationStringency
 import org.apache.parquet.filter2.dsl.Dsl._
-import org.apache.spark.rdd.RDD
 import org.apache.spark.{ Logging, SparkContext }
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 import org.bdgenomics.adam.algorithms.consensus._
 import org.bdgenomics.adam.instrumentation.Timers._
 import org.bdgenomics.adam.models.SnpTable
@@ -105,6 +106,10 @@ class TransformArgs extends Args4jBase with ADAMSaveAnyArgs with ParquetArgs {
   var mdTagsFragmentSize: Long = 1000000L
   @Args4jOption(required = false, name = "-md_tag_overwrite", usage = "When adding MD tags to reads, overwrite existing incorrect tags.")
   var mdTagsOverwrite: Boolean = false
+  @Args4jOption(required = false, name = "-cache", usage = "Cache data to avoid recomputing between stages.")
+  var cache: Boolean = false
+  @Args4jOption(required = false, name = "-storage_level", usage = "Set the storage level to use for caching.")
+  var storageLevel: String = "MEMORY_ONLY"
 }
 
 class Transform(protected val args: TransformArgs) extends BDGSparkCommand[TransformArgs] with Logging {
@@ -116,6 +121,7 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
 
     var adamRecords = rdd
     val sc = rdd.context
+    val sl = StorageLevel.fromString(args.storageLevel)
 
     val stringencyOpt = Option(args.stringency).map(ValidationStringency.valueOf(_))
 
@@ -130,12 +136,18 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
     }
 
     if (args.locallyRealign) {
+      val oldRdd = if (args.cache) {
+        adamRecords.persist(sl)
+      } else {
+        adamRecords
+      }
+
       log.info("Locally realigning indels.")
       val consensusGenerator = Option(args.knownIndelsFile)
         .fold(new ConsensusGeneratorFromReads().asInstanceOf[ConsensusGenerator])(
           new ConsensusGeneratorFromKnowns(_, sc).asInstanceOf[ConsensusGenerator])
 
-      adamRecords = adamRecords.adamRealignIndels(
+      adamRecords = oldRdd.adamRealignIndels(
         consensusGenerator,
         isSorted = false,
         args.maxIndelSize,
@@ -143,16 +155,31 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
         args.lodThreshold,
         args.maxTargetSize
       )
+
+      if (args.cache) {
+        oldRdd.unpersist()
+      }
     }
 
     if (args.recalibrateBaseQualities) {
       log.info("Recalibrating base qualities")
+
+      val oldRdd = if (args.cache) {
+        adamRecords.persist(sl)
+      } else {
+        adamRecords
+      }
+
       val knownSnps: SnpTable = createKnownSnpsTable(sc)
-      adamRecords = adamRecords.adamBQSR(
+      adamRecords = oldRdd.adamBQSR(
         sc.broadcast(knownSnps),
         Option(args.observationsPath),
         stringency
       )
+
+      if (args.cache) {
+        oldRdd.unpersist()
+      }
     }
 
     if (args.coalesce != -1) {
@@ -166,8 +193,18 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
 
     // NOTE: For now, sorting needs to be the last transform
    if (args.sortReads) {
+      val oldRdd = if (args.cache) {
+        adamRecords.persist(sl)
+      } else {
+        adamRecords
+      }
+
       log.info("Sorting reads")
-      adamRecords = adamRecords.adamSortReadsByReferencePosition()
+      adamRecords = oldRdd.adamSortReadsByReferencePosition()
+
+      if (args.cache) {
+        oldRdd.unpersist()
+      }
     }
 
     if (args.mdTagsReferenceFile != null) {
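
For reference, the caching pattern that each hunk above applies is
sketched below in isolation. This is a minimal sketch rather than ADAM
code: twoPassStage is a hypothetical stand-in for a multi-pass stage
such as BQSR or indel realignment, and the hard-coded cache flag stands
in for args.cache.

    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    object CachingPatternSketch {

      // Hypothetical two-pass stage: the first pass runs actions over
      // the input (analogous to BQSR's observation-gathering pass), and
      // the second pass transforms it. Both passes read the input RDD,
      // so caching it avoids recomputing a possibly remote load.
      def twoPassStage(rdd: RDD[Int]): RDD[Int] = {
        val mean = rdd.sum / rdd.count    // pass 1: actions over the input
        rdd.map(r => if (r > mean) r else 0) // pass 2: transformation
      }

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "cachingPatternSketch")

        // Parse the storage level the same way the patch does;
        // "MEMORY_ONLY" mirrors the default of the new -storage_level
        // option.
        val sl = StorageLevel.fromString("MEMORY_ONLY")
        val cache = true // stands in for args.cache

        val records: RDD[Int] = sc.parallelize(1 to 1000000)

        // persist() is lazy: it only marks the RDD for caching, and the
        // blocks are materialized by the first action over it.
        val oldRdd = if (cache) {
          records.persist(sl)
        } else {
          records
        }

        val result = twoPassStage(oldRdd)
        println(result.count)

        // Drop the cached blocks once the stage no longer needs them.
        if (cache) {
          oldRdd.unpersist()
        }

        sc.stop()
      }
    }

Note that persist only marks the RDD; the cache is populated by the
first action, which for the stages touched by this patch happens inside
the stage itself (e.g., the pass that builds the observation table in
BQSR), so by the time unpersist runs the cache has already paid for
itself.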
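With the patch applied, caching is opt-in. Assuming the existing
Transform flag names (-realign_indels for locallyRealign and
-recalibrate_base_qualities for recalibrateBaseQualities), a multi-stage
run over S3-hosted input might look like:

    adam-submit transform in.bam out.adam \
      -realign_indels \
      -recalibrate_base_qualities \
      -cache -storage_level MEMORY_AND_DISK_SER

MEMORY_AND_DISK_SER is one of the names StorageLevel.fromString accepts;
MEMORY_ONLY remains the default when -storage_level is not given.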