From 5ed5c2a65c31c78b7845bbb8a3ef859590453ba9 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 2 Feb 2017 04:49:04 +0000
Subject: [PATCH 1/3] Periodically checkpoint datasets in long ML pipelines.

---
 .../scala/org/apache/spark/ml/Pipeline.scala  | 64 ++++++++++++-
 .../impl/PeriodicDatasetCheckpointer.scala    | 93 +++++++++++++++++++
 .../mllib/impl/PeriodicRDDCheckpointer.scala  |  1 +
 .../scala/org/apache/spark/sql/Dataset.scala  |  4 +
 4 files changed, 160 insertions(+), 2 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index b76dc5f93193..be3c58ab32ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util._
+import org.apache.spark.mllib.impl.PeriodicDatasetCheckpointer
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
@@ -99,6 +100,25 @@ class Pipeline @Since("1.4.0") (
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("pipeline"))
 
+  /**
+   * param for checkpoint interval of pipeline stages
+   * @group param
+   */
+  @Since("2.2.0")
+  val checkpointInterval: Param[Int] =
+    new Param(this, "checkpointInterval", "checkpoint interval for stages")
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setCheckpointInterval(value: Int): this.type = {
+    set(checkpointInterval, value)
+    this
+  }
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getCheckpointInterval: Int = $(checkpointInterval)
+
   /**
    * param for pipeline stages
    * @group param
@@ -144,10 +164,17 @@ class Pipeline @Since("1.4.0") (
         case _ =>
       }
     }
+    val checkpointer = if (isDefined(checkpointInterval)) {
+      Some(new PeriodicDatasetCheckpointer(
+        getCheckpointInterval, dataset.sparkSession.sparkContext))
+    } else {
+      None
+    }
     var curDataset = dataset
     val transformers = ListBuffer.empty[Transformer]
     theStages.view.zipWithIndex.foreach { case (stage, index) =>
       if (index <= indexOfLastEstimator) {
+        checkpointer.foreach(_.update(curDataset))
         val transformer = stage match {
           case estimator: Estimator[_] =>
             estimator.fit(curDataset)
@@ -166,7 +193,11 @@ class Pipeline @Since("1.4.0") (
       }
     }
 
-    new PipelineModel(uid, transformers.toArray).setParent(this)
+    val pipelineModel = new PipelineModel(uid, transformers.toArray).setParent(this)
+    if (isDefined(checkpointInterval)) {
+      pipelineModel.setCheckpointInterval(getCheckpointInterval)
+    }
+    pipelineModel
   }
 
   @Since("1.4.0")
@@ -299,10 +330,39 @@ class PipelineModel private[ml] (
     this(uid, stages.asScala.toArray)
   }
 
+  /**
+   * param for checkpoint interval of pipeline stages
+   * @group param
+   */
+  @Since("2.2.0")
+  val checkpointInterval: Param[Int] =
+    new Param(this, "checkpointInterval", "checkpoint interval for stages")
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setCheckpointInterval(value: Int): this.type = {
+    set(checkpointInterval, value)
+    this
+  }
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getCheckpointInterval: Int = $(checkpointInterval)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
+    val checkpointer = if (isDefined(checkpointInterval)) {
+      Some(new PeriodicDatasetCheckpointer(
+        getCheckpointInterval, dataset.sparkSession.sparkContext))
+    } else {
+      None
+    }
+    stages.foldLeft(dataset.toDF)((cur, transformer) => {
+      val newDF = transformer.transform(cur)
+      checkpointer.foreach(_.update(newDF))
+      newDF
+    })
   }
 
   @Since("1.2.0")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala
new file mode 100644
index 000000000000..f15af24cdfcb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.Dataset
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing Datasets.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created and before the Dataset has been
+ * materialized. Because Dataset.checkpoint() is called with eager = true, an RDD.count() is always
+ * performed after calling RDD.checkpoint() on the underlying RDD of the Dataset. So, unlike
+ * [[PeriodicRDDCheckpointer]], after updating [[PeriodicDatasetCheckpointer]] users do not need to
+ * materialize the Dataset to ensure that persisting and checkpointing actually occur.
+ *
+ * When update() is called, this does the following:
+ *  - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ *  - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *     - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ *     - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which Datasets should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later Datasets have been checkpointed.
+ *    However, references to the RDDs of the older Datasets will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ *  val (ds1, ds2, ds3, ...) = ...
+ *  val cp = new PeriodicDatasetCheckpointer(2, sc)
+ *  cp.update(ds1)
+ *  // persisted: ds1
+ *  cp.update(ds2)
+ *  // persisted: ds1, ds2
+ *  // checkpointed: ds2
+ *  cp.update(ds3)
+ *  // persisted: ds1, ds2, ds3
+ *  // checkpointed: ds2
+ *  cp.update(ds4)
+ *  // persisted: ds2, ds3, ds4
+ *  // checkpointed: ds4
+ *  cp.update(ds5)
+ *  // persisted: ds3, ds4, ds5
+ *  // checkpointed: ds4
+ * }}}
+ *
+ * @param checkpointInterval Datasets will be checkpointed at this interval
+ *
+ * TODO: Move this out of MLlib?
+ */
+private[spark] class PeriodicDatasetCheckpointer(
+    checkpointInterval: Int,
+    sc: SparkContext)
+  extends PeriodicCheckpointer[Dataset[_]](checkpointInterval, sc) {
+
+  override protected def checkpoint(data: Dataset[_]): Unit = data.checkpoint(eager = true)
+
+  override protected def isCheckpointed(data: Dataset[_]): Boolean = data.isCheckpoint
+
+  override protected def persist(data: Dataset[_]): Unit = {
+    if (data.storageLevel == StorageLevel.NONE) {
+      data.persist()
+    }
+  }
+
+  override protected def unpersist(data: Dataset[_]): Unit = data.unpersist()
+
+  override protected def getCheckpointFiles(data: Dataset[_]): Iterable[String] = {
+    data.queryExecution.toRdd.getCheckpointFile.map(x => x)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
index 145dc22b7428..386b95bd975e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel
  * {{{
  *  val (rdd1, rdd2, rdd3, ...) = ...
  *  val cp = new PeriodicRDDCheckpointer(2, sc)
+ *  cp.update(rdd1)
  *  rdd1.count();
  *  // persisted: rdd1
  *  cp.update(rdd2)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 391c34f1285e..2cf8784f057f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -493,6 +493,10 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def isStreaming: Boolean = logicalPlan.isStreaming
 
+  @Experimental
+  @InterfaceStability.Evolving
+  def isCheckpoint: Boolean = queryExecution.toRdd.isCheckpointed
+
   /**
    * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to truncate
    * the logical plan of this Dataset, which is especially useful in iterative algorithms where the

From 7a1b3008a5873600016ebe0649285a724c6f4d7c Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 2 Feb 2017 07:14:38 +0000
Subject: [PATCH 2/3] Reuse the HasCheckpointInterval trait, which already
 defines the checkpoint interval param.
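
For reference, the user-facing API that patches 1 and 2 expose looks roughly like the sketch
below. The SparkSession `spark`, the input DataFrames `trainingDF`/`testDF`, and the checkpoint
directory path are illustrative assumptions, not part of this patch:

  import org.apache.spark.ml.Pipeline
  import org.apache.spark.ml.classification.LogisticRegression
  import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

  // A checkpoint directory must be set on the SparkContext before checkpointing can succeed.
  spark.sparkContext.setCheckpointDir("/tmp/pipeline-checkpoints")

  val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
  val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
  val lr = new LogisticRegression().setMaxIter(10)

  val pipeline = new Pipeline()
    .setStages(Array(tokenizer, hashingTF, lr))
    .setCheckpointInterval(2)  // checkpoint the intermediate Dataset every 2 update() calls

  val model = pipeline.fit(trainingDF)   // trainingDF is assumed to have "text" and "label" columns
  val scored = model.transform(testDF)   // the fitted PipelineModel also checkpoints periodically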

---
 .../scala/org/apache/spark/ml/Pipeline.scala | 30 +++----------------
 1 file changed, 4 insertions(+), 26 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index be3c58ab32ff..0a0140803073 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -30,6 +30,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.shared.HasCheckpointInterval
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.impl.PeriodicDatasetCheckpointer
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -95,19 +96,12 @@ abstract class PipelineStage extends Params with Logging {
  */
 @Since("1.2.0")
 class Pipeline @Since("1.4.0") (
-  @Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable {
+  @Since("1.4.0") override val uid: String)
+  extends Estimator[PipelineModel] with MLWritable with HasCheckpointInterval {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("pipeline"))
 
-  /**
-   * param for checkpoint interval of pipeline stages
-   * @group param
-   */
-  @Since("2.2.0")
-  val checkpointInterval: Param[Int] =
-    new Param(this, "checkpointInterval", "checkpoint interval for stages")
-
   /** @group setParam */
   @Since("2.2.0")
   def setCheckpointInterval(value: Int): this.type = {
@@ -115,10 +109,6 @@ class Pipeline @Since("1.4.0") (
     this
   }
 
-  /** @group getParam */
-  @Since("2.2.0")
-  def getCheckpointInterval: Int = $(checkpointInterval)
-
   /**
    * param for pipeline stages
    * @group param
@@ -323,21 +313,13 @@ object Pipeline extends MLReadable[Pipeline] {
 class PipelineModel private[ml] (
     @Since("1.4.0") override val uid: String,
     @Since("1.4.0") val stages: Array[Transformer])
-  extends Model[PipelineModel] with MLWritable with Logging {
+  extends Model[PipelineModel] with MLWritable with HasCheckpointInterval with Logging {
 
   /** A Java/Python-friendly auxiliary constructor. */
   private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
     this(uid, stages.asScala.toArray)
   }
 
-  /**
-   * param for checkpoint interval of pipeline stages
-   * @group param
-   */
-  @Since("2.2.0")
-  val checkpointInterval: Param[Int] =
-    new Param(this, "checkpointInterval", "checkpoint interval for stages")
-
   /** @group setParam */
   @Since("2.2.0")
   def setCheckpointInterval(value: Int): this.type = {
@@ -345,10 +327,6 @@ class PipelineModel private[ml] (
     this
   }
 
-  /** @group getParam */
-  @Since("2.2.0")
-  def getCheckpointInterval: Int = $(checkpointInterval)
-
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)

From 32c90dd0817778d3a1a0d1a955463d656dd92d60 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 2 Feb 2017 07:48:57 +0000
Subject: [PATCH 3/3] Add a test case to verify the correctness of the result.
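
The test toggles checkpointing by pointing the SparkContext at a temporary checkpoint directory
and fitting the same pipeline twice, with and without a checkpoint interval. The setup pattern it
relies on is roughly the following sketch (`sc` is the SparkContext provided by
MLlibTestSparkContext, and `Utils` is org.apache.spark.util.Utils, as in the test itself):

  val tempDir = Utils.createTempDir()
  sc.setCheckpointDir(tempDir.toURI.toString)  // enables Dataset.checkpoint() during fit/transform
  try {
    // fit and transform with new Pipeline().setStages(...).setCheckpointInterval(2)
  } finally {
    sc.checkpointDir = None            // checkpointDir is private[spark], accessible from the suite
    Utils.deleteRecursively(tempDir)   // remove any checkpoint files written by the test
  }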

---
 .../org/apache/spark/ml/PipelineSuite.scala | 50 ++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index dafc6c200f95..1480269ff2ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,13 +26,14 @@ import org.scalatest.mock.MockitoSugar.mock
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.Pipeline.SharedReadWrite
-import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
+import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler, OneHotEncoder, StringIndexer}
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -227,6 +228,53 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val steps = stages0 ++ stages1
     val p = new Pipeline().setStages(steps)
   }
+
+  test("Pipeline checkpoint interval") {
+    def fitPipeline(doCheckpoint: Boolean): Dataset[_] = {
+      val optTempDir = if (doCheckpoint) {
+        Some(Utils.createTempDir())
+      } else {
+        None
+      }
+
+      optTempDir.foreach { tempDir =>
+        val path = tempDir.toURI.toString
+        sc.setCheckpointDir(path)
+      }
+
+      val df = (1 to 5).foldLeft(Seq((1, "foo"), (2, "bar"), (3, "baz")).toDF("id", "x0"))(
+        (df, i) => df.withColumn(s"x$i", $"x0"))
+      val indexers = df.columns.tail.map(c => new StringIndexer()
+        .setInputCol(c)
+        .setOutputCol(s"${c}_indexed")
+        .setHandleInvalid("skip"))
+
+      val encoders = indexers.map(indexer => new OneHotEncoder()
+        .setInputCol(indexer.getOutputCol)
+        .setOutputCol(s"${indexer.getOutputCol}_encoded")
+        .setDropLast(true))
+
+      val stages: Array[PipelineStage] = indexers ++ encoders
+      val pipeline = new Pipeline().setStages(stages)
+      if (doCheckpoint) {
+        pipeline.setCheckpointInterval(2)
+      }
+
+      val outputDF = pipeline.fit(df).transform(df)
+
+      optTempDir.foreach { tempDir =>
+        sc.checkpointDir = None
+        Utils.deleteRecursively(tempDir)
+      }
+      outputDF
+    }
+    val outputWithoutCheckpoint = fitPipeline(false).collect()
+    val outputWithCheckpoint = fitPipeline(true).collect()
+    assert(outputWithoutCheckpoint.length == outputWithCheckpoint.length)
+    outputWithoutCheckpoint.zip(outputWithCheckpoint).foreach { case (noCheckpoint, withCheckpoint) =>
+      assert(noCheckpoint === withCheckpoint)
+    }
+  }
 }
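
As an end-to-end illustration of the scenario these patches target (a pipeline with many
generated stages), usage modeled on the new test could look like the sketch below; the local
SparkSession and the checkpoint directory path are illustrative assumptions:

  import org.apache.spark.ml.{Pipeline, PipelineStage}
  import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[2]").appName("checkpoint-demo").getOrCreate()
  import spark.implicits._
  spark.sparkContext.setCheckpointDir("/tmp/pipeline-checkpoints")

  // A wide DataFrame with many string columns, built the same way as in the new test case.
  val df = (1 to 20).foldLeft(Seq((1, "foo"), (2, "bar"), (3, "baz")).toDF("id", "x0"))(
    (acc, i) => acc.withColumn(s"x$i", $"x0"))

  // One StringIndexer plus one OneHotEncoder per string column, i.e. dozens of stages in total.
  val indexers = df.columns.tail.map { c =>
    new StringIndexer().setInputCol(c).setOutputCol(s"${c}_indexed").setHandleInvalid("skip")
  }
  val encoders = indexers.map { indexer =>
    new OneHotEncoder()
      .setInputCol(indexer.getOutputCol)
      .setOutputCol(s"${indexer.getOutputCol}_encoded")
  }
  val stages: Array[PipelineStage] = indexers ++ encoders

  val pipeline = new Pipeline()
    .setStages(stages)
    .setCheckpointInterval(5)  // truncate the growing lineage/plan every 5 stages

  val model = pipeline.fit(df)
  val output = model.transform(df)
  output.show(3)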