Commit bc1e766

viirya authored and cloud-fan committed
[SPARK-22348][SQL] The table cache providing ColumnarBatch should also do partition batch pruning
## What changes were proposed in this pull request?

The table cache `InMemoryTableScanExec` can now provide `ColumnarBatch` output, but the cached batches are retrieved without pruning. In this case, we still need to do partition batch pruning.

## How was this patch tested?

Existing tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19569 from viirya/SPARK-22348.
1 parent 3f5ba96 commit bc1e766
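
For readers unfamiliar with partition batch pruning, here is a minimal self-contained sketch of the idea; the names (`BatchStats`, `CachedBatch`, `prunedBatches`) are illustrative stand-ins, not Spark's internal API:

```scala
// Each cached batch carries per-column min/max statistics; a batch whose
// stats cannot possibly satisfy the predicate is skipped without being
// decompressed. This mirrors what the partition filter does against
// cachedBatch.stats in InMemoryTableScanExec.
case class BatchStats(min: Int, max: Int)
case class CachedBatch(rows: Seq[(Int, Int)], stats: BatchStats)

// Keep only batches whose [min, max] range could contain `target`.
def prunedBatches(batches: Iterator[CachedBatch], target: Int): Iterator[CachedBatch] =
  batches.filter(b => b.stats.min <= target && target <= b.stats.max)

val batches = Iterator(
  CachedBatch(Seq((1, 1), (1, 1)), BatchStats(min = 1, max = 1)),
  CachedBatch(Seq((2, 2)), BatchStats(min = 2, max = 2)))

// A predicate like "y = 3" rules out both batches, so nothing is scanned.
assert(prunedBatches(batches, target = 3).isEmpty)
```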

File tree

2 files changed (+64, -33 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 38 additions & 32 deletions

```diff
@@ -78,7 +78,7 @@ case class InMemoryTableScanExec(
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     assert(supportCodegen)
-    val buffers = relation.cachedColumnBuffers
+    val buffers = filteredCachedBatches()
     // HACK ALERT: This is actually an RDD[ColumnarBatch].
     // We're taking advantage of Scala's type erasure here to pass these batches along.
     Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]])
@@ -180,19 +180,11 @@ case class InMemoryTableScanExec(
 
   private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    if (enableAccumulators) {
-      readPartitions.setValue(0)
-      readBatches.setValue(0)
-    }
-
+  private def filteredCachedBatches(): RDD[CachedBatch] = {
     // Using these variables here to avoid serialization of entire objects (if referenced directly)
     // within the map Partitions closure.
     val schema = relation.partitionStatistics.schema
     val schemaIndex = schema.zipWithIndex
-    val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
     buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
@@ -201,35 +193,49 @@ case class InMemoryTableScanExec(
         schema)
       partitionFilter.initialize(index)
 
+      // Do partition batch pruning if enabled
+      if (inMemoryPartitionPruningEnabled) {
+        cachedBatchIterator.filter { cachedBatch =>
+          if (!partitionFilter.eval(cachedBatch.stats)) {
+            logDebug {
+              val statsString = schemaIndex.map { case (a, i) =>
+                val value = cachedBatch.stats.get(i, a.dataType)
+                s"${a.name}: $value"
+              }.mkString(", ")
+              s"Skipping partition based on stats $statsString"
+            }
+            false
+          } else {
+            true
+          }
+        }
+      } else {
+        cachedBatchIterator
+      }
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+
+    if (enableAccumulators) {
+      readPartitions.setValue(0)
+      readBatches.setValue(0)
+    }
+
+    // Using these variables here to avoid serialization of entire objects (if referenced directly)
+    // within the map Partitions closure.
+    val relOutput: AttributeSeq = relation.output
+
+    filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator =>
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
         attributes.map { a =>
           relOutput.indexOf(a.exprId) -> a.dataType
         }.unzip
 
-      // Do partition batch pruning if enabled
-      val cachedBatchesToScan =
-        if (inMemoryPartitionPruningEnabled) {
-          cachedBatchIterator.filter { cachedBatch =>
-            if (!partitionFilter.eval(cachedBatch.stats)) {
-              logDebug {
-                val statsString = schemaIndex.map { case (a, i) =>
-                  val value = cachedBatch.stats.get(i, a.dataType)
-                  s"${a.name}: $value"
-                }.mkString(", ")
-                s"Skipping partition based on stats $statsString"
-              }
-              false
-            } else {
-              true
-            }
-          }
-        } else {
-          cachedBatchIterator
-        }
-
       // update SQL metrics
-      val withMetrics = cachedBatchesToScan.map { batch =>
+      val withMetrics = cachedBatchIterator.map { batch =>
         if (enableAccumulators) {
           readBatches.add(1)
         }
```

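With this refactoring, both the whole-stage-codegen path (`inputRDDs()`) and the row-based path (`doExecute()`) consume the pruned `filteredCachedBatches()`. A minimal usage sketch of how the pruning is exercised through the public API, assuming a local `SparkSession`; `spark.sql.inMemoryColumnarStorage.partitionPruning` is the SQLConf key behind `inMemoryPartitionPruningEnabled`:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("SPARK-22348-demo")
  .getOrCreate()
import spark.implicits._

// Batch-level min/max stats are collected when the DataFrame is cached.
val df = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y")
df.cache()
df.count() // materialize the cache

// "y = 3" falls outside every cached batch's [min, max] stats for y, so
// with pruning enabled (the default) no batch needs to be decompressed.
spark.conf.set("spark.sql.inMemoryColumnarStorage.partitionPruning", "true")
assert(df.where("y = 3").count() == 0)
```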
sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 26 additions & 1 deletion

```diff
@@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp}
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.LocalTableScanExec
+import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -454,4 +454,29 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       Seq(In(attribute, Nil)), testRelation)
     assert(tableScanExec.partitionFilters.isEmpty)
   }
+
+  test("SPARK-22348: table cache should do partition batch pruning") {
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) {
+        val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y")
+        df1.unpersist()
+        df1.cache()
+
+        // Push predicate to the cached table.
+        val df2 = df1.where("y = 3")
+
+        val planBeforeFilter = df2.queryExecution.executedPlan.collect {
+          case f: FilterExec => f.child
+        }
+        assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])
+
+        val execPlan = if (enabled == "true") {
+          WholeStageCodegenExec(planBeforeFilter.head)
+        } else {
+          planBeforeFilter.head
+        }
+        assert(execPlan.executeCollectPublic().length == 0)
+      }
+    }
+  }
 }
```
