From 75b70702643194f331d1891681ef9c0941cce77d Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Mon, 2 Oct 2023 11:36:53 +0200 Subject: [PATCH] [SPARK-45386][SQL]: Fix correctness issue with persist using StorageLevel.NONE on Dataset (#43188) * SPARK-45386: Fix correctness issue with StorageLevel.NONE * Move to CacheManager * Add comment --- .../scala/org/apache/spark/sql/execution/CacheManager.scala | 4 +++- .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 064819275e004..e906c74f8a5ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -113,7 +113,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { planToCache: LogicalPlan, tableName: Option[String], storageLevel: StorageLevel): Unit = { - if (lookupCachedData(planToCache).nonEmpty) { + if (storageLevel == StorageLevel.NONE) { + // Do nothing for StorageLevel.NONE since it will not actually cache any data. + } else if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c967540541a5c..6d9c43f866a0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2) case class TestDataPoint2(x: Int, s: String) @@ -2535,6 +2536,11 @@ class DatasetSuite extends QueryTest checkDataset(ds.filter(f(col("_1"))), Tuple1(ValueClass(2))) } + + test("SPARK-45386: persist with StorageLevel.NONE should give correct count") { + val ds = Seq(1, 2).toDS().persist(StorageLevel.NONE) + assert(ds.count() == 2) + } } class DatasetLargeResultCollectingSuite extends QueryTest