46 changes: 42 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -30,7 +30,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.param.shared.HasCheckpointInterval
import org.apache.spark.ml.util._
import org.apache.spark.mllib.impl.PeriodicDatasetCheckpointer
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

@@ -94,11 +96,19 @@ abstract class PipelineStage extends Params with Logging {
*/
@Since("1.2.0")
class Pipeline @Since("1.4.0") (
@Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable {
@Since("1.4.0") override val uid: String)
extends Estimator[PipelineModel] with MLWritable with HasCheckpointInterval {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("pipeline"))

/** @group setParam */
@Since("2.2.0")
def setCheckpointInterval(value: Int): this.type = {
set(checkpointInterval, value)
this
}

/**
* param for pipeline stages
* @group param
@@ -144,10 +154,17 @@ class Pipeline @Since("1.4.0") (
case _ =>
}
}
val checkpointer = if (isDefined(checkpointInterval)) {
Some(new PeriodicDatasetCheckpointer(
getCheckpointInterval, dataset.sparkSession.sparkContext))
} else {
None
}
var curDataset = dataset
val transformers = ListBuffer.empty[Transformer]
theStages.view.zipWithIndex.foreach { case (stage, index) =>
if (index <= indexOfLastEstimator) {
checkpointer.foreach(_.update(curDataset))
val transformer = stage match {
case estimator: Estimator[_] =>
estimator.fit(curDataset)
@@ -166,7 +183,11 @@ }
}
}

new PipelineModel(uid, transformers.toArray).setParent(this)
val pipelineModel = new PipelineModel(uid, transformers.toArray).setParent(this)
if (isDefined(checkpointInterval)) {
pipelineModel.setCheckpointInterval(getCheckpointInterval)
}
pipelineModel
}

@Since("1.4.0")
@@ -292,17 +313,34 @@ object Pipeline extends MLReadable[Pipeline] {
class PipelineModel private[ml] (
@Since("1.4.0") override val uid: String,
@Since("1.4.0") val stages: Array[Transformer])
extends Model[PipelineModel] with MLWritable with Logging {
extends Model[PipelineModel] with MLWritable with HasCheckpointInterval with Logging {

/** A Java/Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
this(uid, stages.asScala.toArray)
}

/** @group setParam */
@Since("2.2.0")
def setCheckpointInterval(value: Int): this.type = {
set(checkpointInterval, value)
this
}

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
val checkpointer = if (isDefined(checkpointInterval)) {
Some(new PeriodicDatasetCheckpointer(
getCheckpointInterval, dataset.sparkSession.sparkContext))
} else {
None
}
stages.foldLeft(dataset.toDF)((cur, transformer) => {
val newDF = transformer.transform(cur)
checkpointer.foreach(_.update(newDF))
newDF
})
}

@Since("1.2.0")
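A usage sketch of the new checkpointing support (not part of this diff): the SparkSession `spark`, the DataFrame `trainingDF` with string columns "text" and "label", and the checkpoint directory path are assumed names for illustration.

import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.{HashingTF, StringIndexer, Tokenizer}

// Assumed: `spark` is an active SparkSession and `trainingDF` has "text" and "label" columns.
// Checkpointing only takes effect if a checkpoint directory has been set on the SparkContext.
spark.sparkContext.setCheckpointDir("/tmp/pipeline-checkpoints")

val indexer = new StringIndexer().setInputCol("label").setOutputCol("labelIndex")
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")

val stages: Array[PipelineStage] = Array(indexer, tokenizer, hashingTF)
val pipeline = new Pipeline()
  .setStages(stages)
  .setCheckpointInterval(2) // with this patch: checkpoint the intermediate data every 2 updates

val model = pipeline.fit(trainingDF)          // fit() drives a PeriodicDatasetCheckpointer
val transformed = model.transform(trainingDF) // the model inherits the interval from the pipeline

Per the test added below, output with and without an interval set is expected to be identical; checkpointing only truncates lineage.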
93 changes: 93 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicDatasetCheckpointer.scala
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.impl

import org.apache.spark.SparkContext
import org.apache.spark.sql.Dataset
import org.apache.spark.storage.StorageLevel


/**
* This class helps with persisting and checkpointing Datasets.
* Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
* unpersisting and removing checkpoint files.
*
* Users should call update() when a new Dataset has been created and before the Dataset has been
* materialized. Because Dataset.checkpoint() is called with eager = true, an RDD.count() is always
* performed after RDD.checkpoint() on the Dataset's underlying RDD. So, unlike
* [[PeriodicRDDCheckpointer]], users do not need to materialize the Dataset after calling update()
* on [[PeriodicDatasetCheckpointer]] to ensure that persisting and checkpointing actually occur.
*
* When update() is called, this does the following:
* - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
* - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
* - If using checkpointing and the checkpoint interval has been reached,
* - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
* - Remove older checkpoints.
*
* WARNINGS:
* - This class should NOT be copied (since copies may conflict on which Datasets should be
* checkpointed).
* - This class removes checkpoint files once later Datasets have been checkpointed.
* However, references to the RDDs of the older Datasets will still return isCheckpointed = true.
*
* Example usage:
* {{{
* val (ds1, ds2, ds3, ...) = ...
* val cp = new PeriodicDatasetCheckpointer(2, sc)
* cp.update(ds1)
* // persisted: ds1
* cp.update(ds2)
* // persisted: ds1, ds2
* // checkpointed: ds2
* cp.update(ds3)
* // persisted: ds1, ds2, ds3
* // checkpointed: ds2
* cp.update(ds4)
* // persisted: ds2, ds3, ds4
* // checkpointed: ds4
* cp.update(ds5)
* // persisted: ds3, ds4, ds5
* // checkpointed: ds4
* }}}
*
* @param checkpointInterval Datasets will be checkpointed at this interval
*
* TODO: Move this out of MLlib?
*/
private[spark] class PeriodicDatasetCheckpointer(
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[Dataset[_]](checkpointInterval, sc) {

override protected def checkpoint(data: Dataset[_]): Unit = data.checkpoint(eager = true)

override protected def isCheckpointed(data: Dataset[_]): Boolean = data.isCheckpoint

override protected def persist(data: Dataset[_]): Unit = {
if (data.storageLevel == StorageLevel.NONE) {
data.persist()
}
}

override protected def unpersist(data: Dataset[_]): Unit = data.unpersist()

override protected def getCheckpointFiles(data: Dataset[_]): Iterable[String] = {
data.queryExecution.toRdd.getCheckpointFile.map(x => x)
}
}
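Beyond Pipeline.fit/transform, the checkpointer could also be driven directly from an iterative algorithm inside Spark (the class is private[spark]). A minimal sketch under assumed names (`spark`, `initialDF`, `refine`):

import org.apache.spark.mllib.impl.PeriodicDatasetCheckpointer
import org.apache.spark.sql.DataFrame

// Assumed: `spark` has a checkpoint directory configured, `initialDF` is the starting DataFrame,
// and `refine` builds the next (not yet materialized) iterate from the current one.
val checkpointer = new PeriodicDatasetCheckpointer(3, spark.sparkContext)

var current: DataFrame = initialDF
checkpointer.update(current)     // persists iterate 1
for (_ <- 1 to 10) {
  current = refine(current)
  checkpointer.update(current)   // persists each iterate; every 3rd update also checkpoints eagerly
}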
1 change: 1 addition & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel
* {{{
* val (rdd1, rdd2, rdd3, ...) = ...
* val cp = new PeriodicRDDCheckpointer(2, sc)
* cp.update(rdd1)
* rdd1.count();
* // persisted: rdd1
* cp.update(rdd2)
50 changes: 49 additions & 1 deletion mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,13 +26,14 @@ import org.scalatest.mock.MockitoSugar.mock

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler, OneHotEncoder, StringIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

@@ -227,6 +228,53 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
val steps = stages0 ++ stages1
val p = new Pipeline().setStages(steps)
}

test("Pipeline checkpoint interval") {
def fitPipeline(doCheckpoint: Boolean): Dataset[_] = {
val optTempDir = if (doCheckpoint) {
Some(Utils.createTempDir())
} else {
None
}

optTempDir.foreach { tempDir =>
val path = tempDir.toURI.toString
sc.setCheckpointDir(path)
}

val df = (1 to 5).foldLeft(Seq((1, "foo"), (2, "bar"), (3, "baz")).toDF("id", "x0"))(
(df, i) => df.withColumn(s"x$i", $"x0"))
val indexers = df.columns.tail.map(c => new StringIndexer()
.setInputCol(c)
.setOutputCol(s"${c}_indexed")
.setHandleInvalid("skip"))

val encoders = indexers.map(indexer => new OneHotEncoder()
.setInputCol(indexer.getOutputCol)
.setOutputCol(s"${indexer.getOutputCol}_encoded")
.setDropLast(true))

val stages: Array[PipelineStage] = indexers ++ encoders
val pipeline = new Pipeline().setStages(stages)
if (doCheckpoint) {
pipeline.setCheckpointInterval(2)
}

val outputDF = pipeline.fit(df).transform(df)

optTempDir.foreach { tempDir =>
sc.checkpointDir = None
Utils.deleteRecursively(tempDir)
}
outputDF
}
val outputWithoutCheckpoint = fitPipeline(false).collect()
val outputWithCheckpoint = fitPipeline(true).collect()
assert(outputWithoutCheckpoint.length == outputWithCheckpoint.length)
outputWithoutCheckpoint.zip(outputWithCheckpoint).foreach { case (noCheckpoint, withCheckpoint) =>
assert(noCheckpoint === withCheckpoint)
}
}
}


4 changes: 4 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -493,6 +493,10 @@ class Dataset[T] private[sql](
@InterfaceStability.Evolving
def isStreaming: Boolean = logicalPlan.isStreaming

/** Returns true if the underlying RDD of this Dataset has already been checkpointed. */
@Experimental
@InterfaceStability.Evolving
def isCheckpoint: Boolean = queryExecution.toRdd.isCheckpointed

/**
* Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to truncate
* the logical plan of this Dataset, which is especially useful in iterative algorithms where the
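A hedged sketch of the new flag (assuming a SparkSession `spark`), mirroring how PeriodicDatasetCheckpointer above calls checkpoint(eager = true) and then reads isCheckpoint:

// Assumed: `spark` is an active SparkSession; a checkpoint directory must be configured first.
spark.sparkContext.setCheckpointDir("/tmp/dataset-checkpoints")

val df = spark.range(0, 10).toDF("id")
println(df.isCheckpoint)      // false: nothing has been checkpointed yet
df.checkpoint(eager = true)   // eagerly checkpoints the Dataset's underlying RDD
println(df.isCheckpoint)      // with this patch, intended to now report the checkpointed state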