diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 0ea806d6cb50..700c7986dea9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -114,13 +114,13 @@ class CacheManager extends Logging {
   }
 
   /**
-   * Un-cache all the cache entries that refer to the given plan.
+   * Un-cache the cache entry that refers to the given plan.
    */
   def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
     val it = cachedData.iterator()
     while (it.hasNext) {
       val cd = it.next()
-      if (cd.plan.find(_.sameResult(plan)).isDefined) {
+      if (plan.sameResult(cd.plan)) {
         cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
         it.remove()
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index e0561ee2797a..40f07e37fe0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -43,6 +43,10 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     // joined Dataset should not be persisted
     val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
     assert(joined.storageLevel == StorageLevel.NONE)
+    // cleanup
+    ds1.unpersist()
+    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
+
   }
 
   test("persist and unpersist") {
@@ -58,7 +62,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
     // Drop the cache.
     cached.unpersist()
-    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.")
+    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset cached should not be cached.")
   }
 
   test("persist and then rebind right encoder when join 2 datasets") {
@@ -80,6 +84,24 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
   }
 
+  test("SPARK-21478: persist parent and child Dataset and unpersist parent Dataset") {
+    val ds1 = Seq(1).toDF()
+    ds1.persist()
+    ds1.count()
+    assert(ds1.storageLevel.useMemory)
+
+    val ds2 = ds1.select($"value" * 2)
+    ds2.persist()
+    ds2.count()
+    assert(ds2.storageLevel.useMemory)
+
+    ds1.unpersist()
+    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
+    assert(ds2.storageLevel.useMemory, "The Dataset ds2 should be cached.")
+    ds2.unpersist()
+    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
+  }
+
   test("persist and then groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupByKey(_._1)
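
For context, a minimal standalone sketch (not part of the patch; the object name, master setting, and data are illustrative) of the behavior the new test pins down: with the exact `sameResult` check in `uncacheQuery`, unpersisting a parent Dataset no longer drops the cache entry of a Dataset derived from it.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

object Spark21478Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-21478-sketch")
      .getOrCreate()
    import spark.implicits._

    val parent = Seq(1, 2, 3).toDF()          // parent Dataset, cached below
    parent.persist()
    parent.count()                            // materialize the cache

    val child = parent.select($"value" * 2)   // Dataset derived from the parent
    child.persist()
    child.count()

    parent.unpersist()
    // Before the fix, uncacheQuery removed every entry whose plan merely *contained*
    // the parent's plan, so the child's cache was dropped too. With the exact
    // sameResult check, only the parent's entry is removed.
    assert(parent.storageLevel == StorageLevel.NONE)
    assert(child.storageLevel.useMemory)

    spark.stop()
  }
}
```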