InMemoryTableScanExec.scala
@@ -78,7 +78,7 @@ case class InMemoryTableScanExec(

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    assert(supportCodegen)
-    val buffers = relation.cachedColumnBuffers
+    val buffers = filteredCachedBatches()
    // HACK ALERT: This is actually an RDD[ColumnarBatch].
    // We're taking advantage of Scala's type erasure here to pass these batches along.
    Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]])
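The change above routes the whole-stage-codegen read path (`inputRDDs()`) through `filteredCachedBatches()` instead of the raw `relation.cachedColumnBuffers`, so stats-based batch pruning now applies to the columnar path as well, not only to `doExecute()`. A minimal, self-contained sketch of the pruning idea, using hypothetical `Batch`/`BatchStats` types rather than Spark's internals:

```scala
// Hypothetical types for illustration only.
case class BatchStats(min: Int, max: Int)
case class Batch(rows: Seq[Int], stats: BatchStats)

// Keep a batch only if its [min, max] range can possibly satisfy the predicate range.
def prune(batches: Seq[Batch], lower: Int, upper: Int): Seq[Batch] =
  batches.filter(b => b.stats.max >= lower && b.stats.min <= upper)

// A filter like `y = 3` over batches whose max(y) is 2 skips every batch.
val batches = Seq(Batch(Seq(1, 1), BatchStats(1, 1)), Batch(Seq(2), BatchStats(2, 2)))
assert(prune(batches, 3, 3).isEmpty)
```

The real check evaluates the pushed-down `partitionFilters` against each batch's statistics row, as `filteredCachedBatches()` does in the hunks below.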
@@ -180,19 +180,11 @@ case class InMemoryTableScanExec(

  private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning

-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    if (enableAccumulators) {
-      readPartitions.setValue(0)
-      readBatches.setValue(0)
-    }
-
+  private def filteredCachedBatches(): RDD[CachedBatch] = {
    // Using these variables here to avoid serialization of entire objects (if referenced directly)
    // within the map Partitions closure.
    val schema = relation.partitionStatistics.schema
    val schemaIndex = schema.zipWithIndex
-    val relOutput: AttributeSeq = relation.output
    val buffers = relation.cachedColumnBuffers

    buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
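For reference, the `CachedBatch` elements the new method returns carry the compressed column buffers plus a per-batch statistics row that the partition filter is evaluated against; roughly (paraphrased from `InMemoryRelation.scala`, modifiers omitted):

```scala
import org.apache.spark.sql.catalyst.InternalRow

// stats holds per-column values (e.g. lower/upper bounds, null counts) collected
// while building the batch; batch pruning compares predicates against this row.
case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
```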
@@ -201,35 +193,49 @@
        schema)
      partitionFilter.initialize(index)

+      // Do partition batch pruning if enabled
+      if (inMemoryPartitionPruningEnabled) {
+        cachedBatchIterator.filter { cachedBatch =>
+          if (!partitionFilter.eval(cachedBatch.stats)) {
+            logDebug {
+              val statsString = schemaIndex.map { case (a, i) =>
+                val value = cachedBatch.stats.get(i, a.dataType)
+                s"${a.name}: $value"
+              }.mkString(", ")
+              s"Skipping partition based on stats $statsString"
+            }
+            false
+          } else {
+            true
+          }
+        }
+      } else {
+        cachedBatchIterator
+      }
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+
+    if (enableAccumulators) {
+      readPartitions.setValue(0)
+      readBatches.setValue(0)
+    }
+
+    // Using these variables here to avoid serialization of entire objects (if referenced directly)
+    // within the map Partitions closure.
+    val relOutput: AttributeSeq = relation.output
+
+    filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator =>
      // Find the ordinals and data types of the requested columns.
      val (requestedColumnIndices, requestedColumnDataTypes) =
        attributes.map { a =>
          relOutput.indexOf(a.exprId) -> a.dataType
        }.unzip

-      // Do partition batch pruning if enabled
-      val cachedBatchesToScan =
-        if (inMemoryPartitionPruningEnabled) {
-          cachedBatchIterator.filter { cachedBatch =>
-            if (!partitionFilter.eval(cachedBatch.stats)) {
-              logDebug {
-                val statsString = schemaIndex.map { case (a, i) =>
-                  val value = cachedBatch.stats.get(i, a.dataType)
-                  s"${a.name}: $value"
-                }.mkString(", ")
-                s"Skipping partition based on stats $statsString"
-              }
-              false
-            } else {
-              true
-            }
-          }
-        } else {
-          cachedBatchIterator
-        }

      // update SQL metrics
-      val withMetrics = cachedBatchesToScan.map { batch =>
+      val withMetrics = cachedBatchIterator.map { batch =>
        if (enableAccumulators) {
          readBatches.add(1)
        }
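Not part of the diff, but an end-to-end sketch of the behaviour this refactor serves, assuming local mode; the config keys are the existing `spark.sql.inMemoryColumnarStorage.*` settings for batch size and pruning:

```scala
import org.apache.spark.sql.SparkSession

object CachePruningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cache-pruning-demo").getOrCreate()
    import spark.implicits._

    // Batch-level pruning of cached data (on by default; set explicitly for clarity).
    spark.conf.set("spark.sql.inMemoryColumnarStorage.partitionPruning", "true")
    // Each cached batch covers up to this many rows and carries per-column stats.
    spark.conf.set("spark.sql.inMemoryColumnarStorage.batchSize", "10000")

    val df = spark.range(0, 1000000).toDF("id")
    df.cache().count() // materialize the cache

    // Batches whose [min, max] stats for `id` cannot satisfy the predicate are skipped.
    df.filter($"id" === 999999L).show()

    spark.stop()
  }
}
```

With pruning in place, a selective predicate over a cached table touches only the batches whose stats overlap the predicate, which is what the new test in the suite below asserts indirectly by expecting zero rows back.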
InMemoryColumnarQuerySuite.scala
@@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.LocalTableScanExec
+import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -454,4 +454,29 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
      Seq(In(attribute, Nil)), testRelation)
    assert(tableScanExec.partitionFilters.isEmpty)
  }
+
+  test("SPARK-22348: table cache should do partition batch pruning") {
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) {
+        val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y")
+        df1.unpersist()
+        df1.cache()
+
+        // Push predicate to the cached table.
+        val df2 = df1.where("y = 3")
+
+        val planBeforeFilter = df2.queryExecution.executedPlan.collect {
+          case f: FilterExec => f.child
+        }
+        assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])
+
+        val execPlan = if (enabled == "true") {
+          WholeStageCodegenExec(planBeforeFilter.head)
+        } else {
+          planBeforeFilter.head
+        }
+        assert(execPlan.executeCollectPublic().length == 0)
+      }
+    }
+  }
}
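A related way to observe the pruning directly, sketched against the suite's helpers; it assumes the internal `spark.sql.inMemoryTableScanStatistics.enable` flag and the `readBatches` testing accumulator on `InMemoryTableScanExec`, and is not part of this PR:

```scala
withSQLConf(
    "spark.sql.inMemoryTableScanStatistics.enable" -> "true",
    SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
  // Codegen is disabled so the row-based doExecute() path, which updates the
  // accumulators, is the one exercised.
  val df = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y")
  df.cache()
  df.count() // materialize the cache

  val pruned = df.where("y = 3")
  pruned.collect()

  val scan = pruned.queryExecution.executedPlan.collectFirst {
    case s: InMemoryTableScanExec => s
  }.get
  // max(y) over the cached batches is 2, so y = 3 should read no batch at all.
  assert(scan.readBatches.value == 0)
}
```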