diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index c34f65234a48..9f02d0e3d0e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -81,7 +81,7 @@ class CacheManager extends Logging {
   }
 
   private def extractStatsOfPlanForCache(plan: LogicalPlan): Option[Statistics] = {
-    if (plan.conf.cboEnabled && plan.stats.rowCount.isDefined) {
+    if (plan.stats.rowCount.isDefined) {
       Some(plan.stats)
     } else {
       None
@@ -156,7 +156,7 @@ class CacheManager extends Logging {
           storageLevel = cd.cachedRepresentation.storageLevel,
           child = spark.sessionState.executePlan(cd.plan).executedPlan,
           tableName = cd.cachedRepresentation.tableName,
-          stats = extractStatsOfPlanForCache(cd.plan))
+          statsOfPlanToCache = extractStatsOfPlanForCache(cd.plan))
         needToRecache += cd.copy(cachedRepresentation = newCache)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 6750599a9bb0..48787d0f8fd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -38,9 +38,9 @@ object InMemoryRelation {
       storageLevel: StorageLevel,
       child: SparkPlan,
       tableName: Option[String],
-      stats: Option[Statistics]): InMemoryRelation =
+      statsOfPlanToCache: Option[Statistics]): InMemoryRelation =
     new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
-      statsOfPlanToCache = stats)
+      statsOfPlanToCache = statsOfPlanToCache)
 }
 
@@ -74,6 +74,8 @@ case class InMemoryRelation(
 
   override def computeStats(): Statistics = {
     if (batchStats.value == 0L) {
+      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache when
+      // applicable
       statsOfPlanToCache.getOrElse(Statistics(sizeInBytes =
         child.sqlContext.conf.defaultSizeInBytes))
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 3a08d77f6b98..a5925e317260 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel._
+import org.apache.spark.util.Utils
 
 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -480,4 +481,32 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-22673: InMemoryRelation should utilize existing stats whenever possible") {
+    withSQLConf("spark.sql.cbo.enabled" -> "true") {
+      // scalastyle:off
+      val workDir = s"${Utils.createTempDir()}/table1"
+      val data = Seq(100, 200, 300, 400).toDF("count")
+      data.write.parquet(workDir)
+      val dfFromFile = spark.read.parquet(workDir).cache()
+      val inMemoryRelation = dfFromFile.queryExecution.optimizedPlan.collect {
+        case plan: InMemoryRelation => plan
+      }.head
+      // InMemoryRelation's stats is Long.MaxValue before the underlying RDD is materialized
+      assert(inMemoryRelation.computeStats().sizeInBytes === Long.MaxValue)
+      // InMemoryRelation's stats is updated after materializing RDD
+      dfFromFile.collect()
+      assert(inMemoryRelation.computeStats().sizeInBytes === 16)
+      // test of catalog table
+      val dfFromTable = spark.catalog.createTable("table1", workDir).cache()
+      val inMemoryRelation2 = dfFromTable.queryExecution.optimizedPlan.
+        collect { case plan: InMemoryRelation => plan }.head
+      // Even CBO enabled, InMemoryRelation's stats keeps as the default one before table's stats
+      // is calculated
+      assert(inMemoryRelation2.computeStats().sizeInBytes === Long.MaxValue)
+      // InMemoryRelation's stats should be updated after calculating stats of the table
+      spark.sql("ANALYZE TABLE table1 COMPUTE STATISTICS")
+      assert(inMemoryRelation2.computeStats().sizeInBytes === 16)
+    }
+  }
 }
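A minimal spark-shell sketch (not part of the patch) of the user-visible effect: with spark.sql.cbo.enabled set and an analyzed catalog table, the cached relation should report the table's statistics before the columnar RDD is materialized, instead of falling back to spark.sql.defaultSizeInBytes. The table name, path, and data below are illustrative, not taken from the change itself.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("SPARK-22673 sketch").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.cbo.enabled", "true")

// Write a tiny parquet table and register it in the catalog (hypothetical path and table name).
val path = java.nio.file.Files.createTempDirectory("spark-22673").resolve("table1").toString
Seq(100, 200, 300, 400).toDF("count").write.parquet(path)
spark.catalog.createTable("table1", path)
spark.sql("ANALYZE TABLE table1 COMPUTE STATISTICS")

// With this patch, the cached plan should pick up the analyzed stats even though
// nothing has been materialized yet, rather than reporting the default size.
val df = spark.table("table1").cache()
println(df.queryExecution.optimizedPlan.stats.sizeInBytes)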