
Commit 609fba7

Add Dataset.checkpoint() to truncate large query plans
Restore checkpoint directory at the end of the test case
Add eager argument to Dataset.checkpoint()
Address PR comments
1 parent 2c7394a commit 609fba7


4 files changed: +146, -12 lines

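The problem this commit addresses, sketched below: repeatedly deriving a Dataset from itself (typical of iterative jobs) makes the logical plan grow without bound, and analysis/planning cost grows with it. checkpoint() materializes the data and replaces the accumulated plan with a single LogicalRDD leaf. The snippet is illustrative only and not part of this commit; it assumes a SparkSession named spark whose SparkContext already has a checkpoint directory configured.

import spark.implicits._

// Hypothetical iterative job whose logical plan keeps growing (illustrative values).
var df = spark.range(1000000).toDF("id")

for (i <- 1 to 100) {
  df = df.filter('id % (i + 1) =!= 0)
  if (i % 10 == 0) {
    // Truncate the accumulated plan: the result is backed by the checkpointed RDD
    // and a single LogicalRDD leaf instead of the whole operator chain.
    df = df.checkpoint(eager = true)
  }
}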

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 45 additions & 2 deletions
@@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery}
+import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -482,6 +483,48 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def isStreaming: Boolean = logicalPlan.isStreaming

+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager When true, materializes the underlying checkpointed RDD eagerly.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(eager: Boolean = false): Dataset[T] = {
+    val internalRdd = queryExecution.toRdd.map(_.copy())
+    internalRdd.checkpoint()
+
+    if (eager) {
+      internalRdd.count()
+    }
+
+    val physicalPlan = queryExecution.executedPlan
+
+    // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
+    // size of `PartitioningCollection` may grow exponentially for queries involving deep inner
+    // joins.
+    def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
+      partitioning match {
+        case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
+        case p => p
+      }
+    }
+
+    val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning)
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(
+        logicalPlan.output,
+        internalRdd,
+        outputPartitioning,
+        physicalPlan.outputOrdering
+      )(sparkSession)).as[T]
+  }
+
   /**
    * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
    * and all cells will be aligned right. For example:
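A minimal usage sketch of the new API (illustrative, not part of the diff; the SparkSession name and checkpoint directory path are arbitrary examples):

import spark.implicits._

spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

val ds = spark.range(10).repartition('id % 2).filter('id > 5)

// Lazy by default: the internal RDD is checkpointed when the Dataset is first computed.
val lazyCp = ds.checkpoint()

// eager = true materializes the checkpointed RDD immediately (via count()), so the
// returned Dataset is already backed by checkpointed data that carries the captured
// outputPartitioning and outputOrdering.
val eagerCp = ds.checkpoint(eager = true)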

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 31 additions & 6 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T](
 /** Logical plan node for scanning data from an RDD of InternalRow. */
 case class LogicalRDD(
     output: Seq[Attribute],
-    rdd: RDD[InternalRow])(session: SparkSession)
+    rdd: RDD[InternalRow],
+    outputPartitioning: Partitioning = UnknownPartitioning(0),
+    outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession)
   extends LeafNode with MultiInstanceRelation {

   override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil

-  override def newInstance(): LogicalRDD.this.type =
-    LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type]
+  override def newInstance(): LogicalRDD.this.type = {
+    val rewrite = output.zip(output.map(_.newInstance())).toMap
+
+    val rewrittenPartitioning = outputPartitioning match {
+      case p: Expression =>
+        p.transform {
+          case e: Attribute => rewrite.getOrElse(e, e)
+        }.asInstanceOf[Partitioning]
+
+      case p => p
+    }
+
+    val rewrittenOrdering = outputOrdering.map(_.transform {
+      case e: Attribute => rewrite.getOrElse(e, e)
+    }.asInstanceOf[SortOrder])
+
+    LogicalRDD(
+      output.map(rewrite),
+      rdd,
+      rewrittenPartitioning,
+      rewrittenOrdering
+    )(session).asInstanceOf[this.type]
+  }

   override def sameResult(plan: LogicalPlan): Boolean = {
     plan.canonicalized match {
-      case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
+      case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
       case _ => false
     }
   }
@@ -158,7 +181,9 @@ case class LogicalRDD(
 case class RDDScanExec(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
-    override val nodeName: String) extends LeafExecNode {
+    override val nodeName: String,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 4 deletions
@@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.StreamingQuery

 /**
  * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
-      case r : logical.Range =>
+      case r: logical.Range =>
         execution.RangeExec(r) :: Nil
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
-      case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil
+      case r: LogicalRDD =>
+        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
     }
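Because the planner now forwards the LogicalRDD's partitioning and ordering into RDDScanExec, EnsureRequirements can see that checkpointed data already satisfies a matching distribution instead of inserting an extra Exchange. A quick, illustrative way to observe this (not part of the diff; same SparkSession/checkpoint-directory assumptions as above):

import spark.implicits._

val cp = spark.range(10).repartition('id % 2).checkpoint(eager = true)

// Grouping by the expression the data was repartitioned on should not force an extra
// Exchange to be inserted above the RDD scan of the checkpointed data.
cp.groupBy('id % 2).count().explain()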

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 67 additions & 0 deletions
@@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp}

 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

@@ -919,6 +922,70 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     df.withColumn("b", expr("0")).as[ClassData]
       .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
+
+  Seq(true, false).foreach { eager =>
+    def testCheckpointing(testName: String)(f: => Unit): Unit = {
+      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
+        withTempDir { dir =>
+          val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+          try {
+            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+            f
+            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+          } finally {
+            spark.sparkContext.checkpointDir = originalCheckpointDir
+          }
+        }
+      }
+    }
+
+    testCheckpointing("basic") {
+      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+      val cp = ds.checkpoint(eager)
+
+      val logicalRDD = cp.logicalPlan match {
+        case plan: LogicalRDD => plan
+        case _ =>
+          val treeString = cp.logicalPlan.treeString(verbose = true)
+          fail(s"Expecting a LogicalRDD, but got\n$treeString")
+      }
+
+      val dsPhysicalPlan = ds.queryExecution.executedPlan
+      val cpPhysicalPlan = cp.queryExecution.executedPlan
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+
+      // For a lazy checkpoint() call, the first check also materializes the checkpoint.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+
+      // Reads back from checkpointed data and check again.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+    }
+
+    testCheckpointing("should preserve partitioning information") {
+      val ds = spark.range(10).repartition('id % 2)
+      val cp = ds.checkpoint(eager)
+
+      val agg = cp.groupBy('id % 2).agg(count('id))
+
+      agg.queryExecution.executedPlan.collectFirst {
+        case ShuffleExchange(_, _: RDDScanExec, _) =>
+        case BroadcastExchangeExec(_, _: RDDScanExec) =>
+      }.foreach { _ =>
+        fail(
+          "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
+            "preserves partitioning information:\n\n" + agg.queryExecution
+        )
+      }
+
+      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+    }
+  }
 }

 case class Generic[T](id: T, value: Double)
