diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index b76dc5f93193..0a0140803073 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -30,7 +30,9 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.shared.HasCheckpointInterval
 import org.apache.spark.ml.util._
+import org.apache.spark.mllib.impl.PeriodicDatasetCheckpointer
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType

@@ -94,11 +96,19 @@ abstract class PipelineStage extends Params with Logging {
  */
 @Since("1.2.0")
 class Pipeline @Since("1.4.0") (
-  @Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable {
+  @Since("1.4.0") override val uid: String)
+  extends Estimator[PipelineModel] with MLWritable with HasCheckpointInterval {

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("pipeline"))

+  /** @group setParam */
+  @Since("2.2.0")
+  def setCheckpointInterval(value: Int): this.type = {
+    set(checkpointInterval, value)
+    this
+  }
+
   /**
    * param for pipeline stages
    * @group param
@@ -144,10 +154,17 @@ class Pipeline @Since("1.4.0") (
         case _ =>
       }
     }
+    val checkpointer = if (isDefined(checkpointInterval)) {
+      Some(new PeriodicDatasetCheckpointer(
+        getCheckpointInterval, dataset.sparkSession.sparkContext))
+    } else {
+      None
+    }
     var curDataset = dataset
     val transformers = ListBuffer.empty[Transformer]
     theStages.view.zipWithIndex.foreach { case (stage, index) =>
       if (index <= indexOfLastEstimator) {
+        checkpointer.foreach(_.update(curDataset))
         val transformer = stage match {
           case estimator: Estimator[_] =>
             estimator.fit(curDataset)
@@ -166,7 +183,11 @@ class Pipeline @Since("1.4.0") (
       }
     }

-    new PipelineModel(uid, transformers.toArray).setParent(this)
+    val pipelineModel = new PipelineModel(uid, transformers.toArray).setParent(this)
+    if (isDefined(checkpointInterval)) {
+      pipelineModel.setCheckpointInterval(getCheckpointInterval)
+    }
+    pipelineModel
   }

   @Since("1.4.0")
@@ -292,17 +313,34 @@ object Pipeline extends MLReadable[Pipeline] {
 class PipelineModel private[ml] (
     @Since("1.4.0") override val uid: String,
     @Since("1.4.0") val stages: Array[Transformer])
-  extends Model[PipelineModel] with MLWritable with Logging {
+  extends Model[PipelineModel] with MLWritable with HasCheckpointInterval with Logging {

   /** A Java/Python-friendly auxiliary constructor. */
   private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
     this(uid, stages.asScala.toArray)
   }

+  /** @group setParam */
+  @Since("2.2.0")
+  def setCheckpointInterval(value: Int): this.type = {
+    set(checkpointInterval, value)
+    this
+  }
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
+    val checkpointer = if (isDefined(checkpointInterval)) {
+      Some(new PeriodicDatasetCheckpointer(
+        getCheckpointInterval, dataset.sparkSession.sparkContext))
+    } else {
+      None
+    }
+    stages.foldLeft(dataset.toDF)((cur, transformer) => {
+      val newDF = transformer.transform(cur)
+      checkpointer.foreach(_.update(newDF))
+      newDF
+    })
   }

   @Since("1.2.0")
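[Reviewer note, not part of the patch] A minimal sketch of the user-facing API added above. It assumes a running SparkSession named spark and a reachable checkpoint directory; both are illustrative, as are the column names.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Assumed for illustration: a running SparkSession `spark` and a writable checkpoint directory.
spark.sparkContext.setCheckpointDir("/tmp/pipeline-checkpoints")

val df = spark.createDataFrame(Seq((1, "foo"), (2, "bar"), (3, "baz"))).toDF("id", "x0")

val indexer = new StringIndexer().setInputCol("x0").setOutputCol("x0_idx")
val encoder = new OneHotEncoder().setInputCol("x0_idx").setOutputCol("x0_vec")

// checkpointInterval comes from HasCheckpointInterval. When it is set, fit() passes the
// intermediate Dataset to a PeriodicDatasetCheckpointer before fitting each stage up to the
// last estimator (persist on every update, checkpoint every 2nd update), and the fitted
// PipelineModel inherits the interval for transform().
val model = new Pipeline()
  .setStages(Array(indexer, encoder))
  .setCheckpointInterval(2)
  .fit(df)

model.transform(df).show()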
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala
new file mode 100644
index 000000000000..f15af24cdfcb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.Dataset
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing Datasets.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created and before the Dataset has been
+ * materialized.  Because Dataset.checkpoint() is called with eager = true, it always performs an
+ * RDD.count() after calling RDD.checkpoint() on the underlying RDD of the Dataset.  So, unlike
+ * [[PeriodicRDDCheckpointer]], users do not need to materialize the Dataset after calling
+ * update() to ensure that persisting and checkpointing actually occur.
+ *
+ * When update() is called, this does the following:
+ *  - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ *  - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *     - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ *     - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which Datasets should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later Datasets have been checkpointed.
+ *    However, references to the RDDs of the older Datasets will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ *  val (ds1, ds2, ds3, ...) = ...
+ *  val cp = new PeriodicDatasetCheckpointer(2, sc)
+ *  cp.update(ds1)
+ *  // persisted: ds1
+ *  cp.update(ds2)
+ *  // persisted: ds1, ds2
+ *  // checkpointed: ds2
+ *  cp.update(ds3)
+ *  // persisted: ds1, ds2, ds3
+ *  // checkpointed: ds2
+ *  cp.update(ds4)
+ *  // persisted: ds2, ds3, ds4
+ *  // checkpointed: ds4
+ *  cp.update(ds5)
+ *  // persisted: ds3, ds4, ds5
+ *  // checkpointed: ds4
+ * }}}
+ *
+ * @param checkpointInterval Datasets will be checkpointed at this interval
+ *
+ * TODO: Move this out of MLlib?
+ */
+private[spark] class PeriodicDatasetCheckpointer(
+    checkpointInterval: Int,
+    sc: SparkContext)
+  extends PeriodicCheckpointer[Dataset[_]](checkpointInterval, sc) {
+
+  override protected def checkpoint(data: Dataset[_]): Unit = data.checkpoint(eager = true)
+
+  override protected def isCheckpointed(data: Dataset[_]): Boolean = data.isCheckpoint
+
+  override protected def persist(data: Dataset[_]): Unit = {
+    if (data.storageLevel == StorageLevel.NONE) {
+      data.persist()
+    }
+  }
+
+  override protected def unpersist(data: Dataset[_]): Unit = data.unpersist()
+
+  override protected def getCheckpointFiles(data: Dataset[_]): Iterable[String] = {
+    data.queryExecution.toRdd.getCheckpointFile.map(x => x)
+  }
+}
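[Reviewer note, not part of the patch] PeriodicDatasetCheckpointer is private[spark], so it is meant to be driven by iterative algorithm code inside the org.apache.spark namespace. A hedged sketch of that pattern follows; the SparkSession, the input schema, and the per-iteration transformation are illustrative assumptions, not taken from the patch.

import org.apache.spark.mllib.impl.PeriodicDatasetCheckpointer
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

// Sketch only: assumes `spark` is a running SparkSession with a checkpoint directory set,
// and that `input` has a numeric "score" column (placeholder for real iterative work).
def iterate(spark: SparkSession, input: DataFrame, numIterations: Int): DataFrame = {
  // Persist on every update; checkpoint every 3rd update. Older checkpoint files are
  // removed automatically as newer checkpoints are written.
  val checkpointer = new PeriodicDatasetCheckpointer(3, spark.sparkContext)
  var current = input
  checkpointer.update(current)
  for (_ <- 1 to numIterations) {
    current = current.withColumn("score", col("score") * 0.9) // placeholder transformation
    checkpointer.update(current) // checkpoint is eager, so no extra materialization is needed
  }
  current
}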
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
index 145dc22b7428..386b95bd975e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel
  * {{{
  *  val (rdd1, rdd2, rdd3, ...) = ...
  *  val cp = new PeriodicRDDCheckpointer(2, sc)
+ *  cp.update(rdd1)
  *  rdd1.count();
  *  // persisted: rdd1
  *  cp.update(rdd2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index dafc6c200f95..1480269ff2ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,13 +26,14 @@ import org.scalatest.mock.MockitoSugar.mock

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.Pipeline.SharedReadWrite
-import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
+import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler, OneHotEncoder, StringIndexer}
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils

 class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

@@ -227,6 +228,53 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val steps = stages0 ++ stages1
     val p = new Pipeline().setStages(steps)
   }
+
+  test("Pipeline checkpoint interval") {
+    def fitPipeline(doCheckpoint: Boolean): Dataset[_] = {
+      val optTempDir = if (doCheckpoint) {
+        Some(Utils.createTempDir())
+      } else {
+        None
+      }
+
+      optTempDir.foreach { tempDir =>
+        val path = tempDir.toURI.toString
+        sc.setCheckpointDir(path)
+      }
+
+      val df = (1 to 5).foldLeft(Seq((1, "foo"), (2, "bar"), (3, "baz")).toDF("id", "x0"))(
+        (df, i) => df.withColumn(s"x$i", $"x0"))
+      val indexers = df.columns.tail.map(c => new StringIndexer()
+        .setInputCol(c)
+        .setOutputCol(s"${c}_indexed")
+        .setHandleInvalid("skip"))
+
+      val encoders = indexers.map(indexer => new OneHotEncoder()
+        .setInputCol(indexer.getOutputCol)
+        .setOutputCol(s"${indexer.getOutputCol}_encoded")
+        .setDropLast(true))
+
+      val stages: Array[PipelineStage] = indexers ++ encoders
+      val pipeline = new Pipeline().setStages(stages)
+      if (doCheckpoint) {
+        pipeline.setCheckpointInterval(2)
+      }
+
+      val outputDF = pipeline.fit(df).transform(df)
+
+      optTempDir.foreach { tempDir =>
+        sc.checkpointDir = None
+        Utils.deleteRecursively(tempDir)
+      }
+      outputDF
+    }
+    val outputWithoutCheckpoint = fitPipeline(false).collect()
+    val outputWithCheckpoint = fitPipeline(true).collect()
+    assert(outputWithoutCheckpoint.length == outputWithCheckpoint.length)
+    outputWithoutCheckpoint.zip(outputWithCheckpoint).foreach { case (expected, actual) =>
+      assert(expected === actual)
+    }
+  }
 }


diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 391c34f1285e..2cf8784f057f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -493,6 +493,11 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def isStreaming: Boolean = logicalPlan.isStreaming

+  /** Returns true if the underlying RDD of this Dataset has been checkpointed. */
+  @Experimental
+  @InterfaceStability.Evolving
+  def isCheckpoint: Boolean = queryExecution.toRdd.isCheckpointed
+
   /**
    * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to truncate
    * the logical plan of this Dataset, which is especially useful in iterative algorithms where the
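[Reviewer note, not part of the patch] A small sketch of the Dataset.isCheckpoint accessor added above; the SparkSession and the checkpoint directory are assumed for illustration.

// Assumed: a running SparkSession `spark`; the checkpoint directory is illustrative.
spark.sparkContext.setCheckpointDir("/tmp/ds-checkpoints")

val df = spark.range(0, 1000).toDF("id")
df.isCheckpoint // false: the underlying RDD has not been checkpointed yet

// The existing eager-checkpoint API that PeriodicDatasetCheckpointer builds on;
// the returned Dataset is backed by the checkpointed data with a truncated lineage.
val cpDF = df.checkpoint(eager = true)
// isCheckpoint reports queryExecution.toRdd.isCheckpointed for the Dataset it is called on.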