sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (55 additions, 2 deletions)
@@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -482,6 +483,58 @@ class Dataset[T] private[sql](
@InterfaceStability.Evolving
def isStreaming: Boolean = logicalPlan.isStreaming

/**
* Returns a checkpointed version of this Dataset.
*
* @group basic
* @since 2.1.0
*/
@Experimental
@InterfaceStability.Evolving
def checkpoint(): Dataset[T] = checkpoint(eager = true)

/**
* Returns a checkpointed version of this Dataset.
*
* @param eager When true, materializes the underlying checkpointed RDD eagerly.
*
* @group basic
* @since 2.1.0
*/
@Experimental
@InterfaceStability.Evolving
def checkpoint(eager: Boolean): Dataset[T] = {
val internalRdd = queryExecution.toRdd.map(_.copy())
internalRdd.checkpoint()

if (eager) {
internalRdd.count()
}

val physicalPlan = queryExecution.executedPlan

// Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
// size of `PartitioningCollection` may grow exponentially for queries involving deep inner
// joins.
def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
partitioning match {
case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
case p => p
}
}
Contributor Author:
The reason we would like to pick the first leaf Partitioning here is that PartitioningCollection, which is also an Expression and participates in query planning, may grow exponentially in the benchmark snippet, which essentially builds a full binary tree of joins.

Member:
Are the partitionings other than the first one useful? Can we just filter out the partitionings guaranteed by the other partitionings, instead of picking only the first?

Contributor Author:
There can be cases where the optimizer fails to eliminate an unnecessary shuffle if we strip all the other partitionings. But that's still better than an exponentially growing PartitioningCollection, which basically runs into the same slow query planning issue this PR tries to solve.

I talked to @yhuai offline about exactly the same issue you brought up before sending out this PR, and we decided to have a working version first and optimize it later since we still need feedback from ML people to see whether the basic mechanism works for their workloads.


val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning)

Dataset.ofRows(
sparkSession,
LogicalRDD(
logicalPlan.output,
internalRdd,
outputPartitioning,
physicalPlan.outputOrdering
)(sparkSession)).as[T]
}

/**
* Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
* and all cells will be aligned right. For example:
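For context, not part of the patch: a minimal usage sketch of the new API, assuming a SparkSession named spark, spark.implicits._ in scope, and a hypothetical checkpoint directory. Because LogicalRDD now carries the output partitioning, aggregating on the same key as the earlier repartition should plan without an extra Exchange above the scan of the checkpointed data.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.count

val spark = SparkSession.builder().appName("checkpoint-sketch").getOrCreate()
import spark.implicits._

// Hypothetical checkpoint location; any reliable storage path works.
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

val ds = spark.range(100).repartition($"id" % 2)   // hash-partitioned by id % 2
val cp = ds.checkpoint()                           // eager by default: materializes the checkpointed RDD now

// Aggregating on the same key should reuse the preserved HashPartitioning, so no
// extra shuffle is expected above the RDD scan of the checkpointed data.
cp.groupBy($"id" % 2).agg(count($"id")).explain()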
sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils
@@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T](
/** Logical plan node for scanning data from an RDD of InternalRow. */
case class LogicalRDD(
output: Seq[Attribute],
rdd: RDD[InternalRow])(session: SparkSession)
rdd: RDD[InternalRow],
outputPartitioning: Partitioning = UnknownPartitioning(0),
outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession)
extends LeafNode with MultiInstanceRelation {

override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil

override def newInstance(): LogicalRDD.this.type =
LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type]
override def newInstance(): LogicalRDD.this.type = {
val rewrite = output.zip(output.map(_.newInstance())).toMap

val rewrittenPartitioning = outputPartitioning match {
case p: Expression =>
p.transform {
case e: Attribute => rewrite.getOrElse(e, e)
}.asInstanceOf[Partitioning]

case p => p
Contributor:
Can you explain this?

Contributor Author:
Not all Partitioning classes are Expressions, and we only need to rewrite attributes within those Partitionings that are also Expressions.

}

val rewrittenOrdering = outputOrdering.map(_.transform {
case e: Attribute => rewrite.getOrElse(e, e)
}.asInstanceOf[SortOrder])

LogicalRDD(
output.map(rewrite),
rdd,
rewrittenPartitioning,
rewrittenOrdering
)(session).asInstanceOf[this.type]
}

override def sameResult(plan: LogicalPlan): Boolean = {
plan.canonicalized match {
case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
case _ => false
}
}
@@ -158,7 +181,9 @@ case class LogicalRDD(
case class RDDScanExec(
output: Seq[Attribute],
rdd: RDD[InternalRow],
override val nodeName: String) extends LeafExecNode {
override val nodeName: String,
override val outputPartitioning: Partitioning = UnknownPartitioning(0),
override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
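A footnote to the exchange above, not part of the patch: an illustrative sketch of the Expression-vs-non-Expression distinction that newInstance() relies on, assuming the catalyst classes as of Spark 2.1. HashPartitioning references attributes and is an Expression, so its attributes must be remapped to the fresh output attributes, while UnknownPartitioning carries no attributes and falls through the default case unchanged.

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.types.LongType

val id = AttributeReference("id", LongType)()           // original output attribute
val freshId = id.newInstance()                          // re-instantiated attribute with a new exprId
val rewrite: Map[Attribute, Attribute] = Map(id -> freshId)

def rewriteIfExpression(p: Partitioning): Partitioning = p match {
  case e: Expression =>
    // HashPartitioning(id :: Nil, 10) takes this branch: its attribute is remapped.
    e.transform { case a: Attribute => rewrite.getOrElse(a, a) }.asInstanceOf[Partitioning]
  case other =>
    // UnknownPartitioning(0) takes this branch: there is nothing to rewrite.
    other
}

rewriteIfExpression(HashPartitioning(id :: Nil, 10))    // now references freshId
rewriteIfExpression(UnknownPartitioning(0))             // returned unchanged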
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQuery

/**
* Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
case r : logical.Range =>
case r: logical.Range =>
execution.RangeExec(r) :: Nil
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
}
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala (68 additions, 0 deletions)
@@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

@@ -919,6 +922,71 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
df.withColumn("b", expr("0")).as[ClassData]
.groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
}

Seq(true, false).foreach { eager =>
def testCheckpointing(testName: String)(f: => Unit): Unit = {
test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
withTempDir { dir =>
val originalCheckpointDir = spark.sparkContext.checkpointDir

try {
spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
f
} finally {
// Since the original checkpointDir can be None, we need
// to set the variable directly.
spark.sparkContext.checkpointDir = originalCheckpointDir
}
}
}
}

testCheckpointing("basic") {
val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
val cp = ds.checkpoint(eager)

val logicalRDD = cp.logicalPlan match {
case plan: LogicalRDD => plan
case _ =>
val treeString = cp.logicalPlan.treeString(verbose = true)
fail(s"Expecting a LogicalRDD, but got\n$treeString")
}

val dsPhysicalPlan = ds.queryExecution.executedPlan
val cpPhysicalPlan = cp.queryExecution.executedPlan

assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }

assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }

// For a lazy checkpoint() call, the first check also materializes the checkpoint.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)

// Reads back from checkpointed data and check again.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
}

testCheckpointing("should preserve partitioning information") {
val ds = spark.range(10).repartition('id % 2)
val cp = ds.checkpoint(eager)

val agg = cp.groupBy('id % 2).agg(count('id))

agg.queryExecution.executedPlan.collectFirst {
case ShuffleExchange(_, _: RDDScanExec, _) =>
case BroadcastExchangeExec(_, _: RDDScanExec) =>
}.foreach { _ =>
fail(
"No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
"preserves partitioning information:\n\n" + agg.queryExecution
)
}

checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
}
}
}

case class Generic[T](id: T, value: Double)