diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index fceb9db411200..1b2e802ae9395 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -587,8 +587,10 @@ case class AdaptiveSparkPlanExec(
           BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized)
         }
       case i: InMemoryTableScanExec =>
-        // No need to optimize `InMemoryTableScanExec` as it's a leaf node.
-        TableCacheQueryStageExec(currentStageId, i)
+        // Apply `queryStageOptimizerRules` so that we can reuse subquery.
+        // No need to apply `postStageCreationRules` for `InMemoryTableScanExec`
+        // as it's a leaf node.
+        TableCacheQueryStageExec(currentStageId, optimizeQueryStage(i, isFinalStage = false))
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, plan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index 947a7314142fe..1f05adc57a4bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -125,7 +125,7 @@ case class InsertAdaptiveSparkPlan(
   /**
    * Returns an expression-id-to-execution-plan map for all the sub-queries.
    * For each sub-query, generate the adaptive execution plan for each sub-query by applying this
-   * rule, or reuse the execution plan from another sub-query of the same semantics if possible.
+   * rule.
    */
   private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = {
     val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala
index c1d0e93e3b979..df6849447215d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala
@@ -33,11 +33,16 @@ case class ReuseAdaptiveSubquery(

     plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
       case sub: ExecSubqueryExpression =>
-        val newPlan = reuseMap.getOrElseUpdate(sub.plan.canonicalized, sub.plan)
-        if (newPlan.ne(sub.plan)) {
-          sub.withNewPlan(ReusedSubqueryExec(newPlan))
-        } else {
-          sub
+        // The subquery can be already reused (the same Java object) due to filter pushdown
+        // of table cache. If it happens, we just need to wrap the current subquery with
+        // `ReusedSubqueryExec` and no need to update the `reuseMap`.
+        reuseMap.get(sub.plan.canonicalized).map { subquery =>
+          sub.withNewPlan(ReusedSubqueryExec(subquery))
+        }.getOrElse {
+          reuseMap.putIfAbsent(sub.plan.canonicalized, sub.plan) match {
+            case Some(subquery) => sub.withNewPlan(ReusedSubqueryExec(subquery))
+            case None => sub
+          }
         }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 5548108b91508..1f2235a10a9ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join, JoinStrategyHint, SHUFFLE_HASH}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.functions._
@@ -823,21 +823,33 @@ class CachedTableSuite extends QueryTest with SQLTestUtils

   test("SPARK-19993 subquery with cached underlying relation") {
     withTempView("t1") {
-      Seq(1).toDF("c1").createOrReplaceTempView("t1")
-      spark.catalog.cacheTable("t1")
-
-      // underlying table t1 is cached as well as the query that refers to it.
-      val sqlText =
-        """
-          |SELECT * FROM t1
-          |WHERE
-          |NOT EXISTS (SELECT * FROM t1)
-        """.stripMargin
-      val ds = sql(sqlText)
-      assert(getNumInMemoryRelations(ds) == 2)
-
-      val cachedDs = sql(sqlText).cache()
-      assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3)
+      Seq(false, true).foreach { enabled =>
+        withSQLConf(
+          SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> enabled.toString,
+          SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key ->
+            AQEPropagateEmptyRelation.ruleName) {
+
+          Seq(1).toDF("c1").createOrReplaceTempView("t1")
+          spark.catalog.cacheTable("t1")
+
+          // underlying table t1 is cached as well as the query that refers to it.
+          val sqlText =
+            """
+              |SELECT * FROM t1
+              |WHERE
+              |NOT EXISTS (SELECT * FROM t1)
+            """.stripMargin
+          val ds = sql(sqlText)
+          assert(getNumInMemoryRelations(ds) == 2)
+
+          val cachedDs = sql(sqlText).cache()
+          cachedDs.collect()
+          assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.executedPlan) == 3)
+
+          cachedDs.unpersist()
+          spark.catalog.uncacheTable("t1")
+        }
+      }
     }
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 32d913ca3b425..2425854e3c8f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2337,15 +2337,9 @@ class SubquerySuite extends QueryTest
         case rs: ReusedSubqueryExec => rs.child.id
       }

-      if (enableAQE) {
-        assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan")
-        assert(reusedSubqueryIds.size == 4,
-          "Missing or unexpected reused ReusedSubqueryExec in the plan")
-      } else {
-        assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
-        assert(reusedSubqueryIds.size == 5,
-          "Missing or unexpected reused ReusedSubqueryExec in the plan")
-      }
+      assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
+      assert(reusedSubqueryIds.size == 5,
+        "Missing or unexpected reused ReusedSubqueryExec in the plan")
     }
   }
 }
@@ -2413,15 +2407,9 @@ class SubquerySuite extends QueryTest
         case rs: ReusedSubqueryExec => rs.child.id
       }

-      if (enableAQE) {
-        assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan")
-        assert(reusedSubqueryIds.size == 3,
-          "Missing or unexpected reused ReusedSubqueryExec in the plan")
-      } else {
-        assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
-        assert(reusedSubqueryIds.size == 4,
-          "Missing or unexpected reused ReusedSubqueryExec in the plan")
-      }
+      assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
+      assert(reusedSubqueryIds.size == 4,
+        "Missing or unexpected reused ReusedSubqueryExec in the plan")
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 7d0879c21d5fb..58936f5d8dc82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -2826,6 +2826,21 @@ class AdaptiveQueryExecSuite
         .executedPlan.isInstanceOf[LocalTableScanExec])
     }
   }
+
+  test("SPARK-43376: Improve reuse subquery with table cache") {
+    withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
+      withTable("t1", "t2") {
+        withCache("t1") {
+          Seq(1).toDF("c1").cache().createOrReplaceTempView("t1")
+          Seq(2).toDF("c2").createOrReplaceTempView("t2")
+
+          val (_, adaptive) = runAdaptiveAndVerifyResult(
+            "SELECT * FROM t1 WHERE c1 < (SELECT c2 FROM t2)")
+          assert(findReusedSubquery(adaptive).size == 1)
+        }
+      }
+    }
+  }
 }

 /**
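
Note on the ReuseAdaptiveSubquery change above: the old `getOrElseUpdate` + `ne` check returned the subquery unwrapped when the registered plan was the same Java object as the current one, so a subquery expression duplicated by filter pushdown into the table cache was never marked as reused; the new code wraps with `ReusedSubqueryExec` whenever the canonicalized plan is already in the map, and uses `putIfAbsent` so a concurrent registration of the same canonicalized plan is still reused. Below is a minimal standalone sketch of that decision logic, not part of the patch: `ReuseDecisionSketch`, `Subquery`, and `decide` are illustrative stand-ins (the real map is keyed by the canonicalized SparkPlan and holds BaseSubqueryExec instances).

import scala.collection.concurrent.TrieMap

object ReuseDecisionSketch {
  // Stand-in for BaseSubqueryExec; `key` models the canonicalized plan.
  final case class Subquery(id: Int, key: String)

  // Models ReuseAdaptiveSubquery's reuseMap: canonicalized plan -> first registered instance.
  val reuseMap = new TrieMap[String, Subquery]()

  // Left = first occurrence, keep the subquery to be executed;
  // Right = an equivalent subquery is already registered, so the caller should
  // wrap the expression in ReusedSubqueryExec. Unlike the old getOrElseUpdate +
  // `ne` check, hitting an entry that is the same object as `sub` still yields
  // Right, which is exactly the filter-pushdown-into-table-cache case.
  def decide(sub: Subquery): Either[Subquery, Subquery] =
    reuseMap.get(sub.key) match {
      case Some(existing) => Right(existing)
      case None =>
        // putIfAbsent instead of a plain update: if another thread registered
        // a plan for this key first, reuse the winner's instance.
        reuseMap.putIfAbsent(sub.key, sub) match {
          case Some(existing) => Right(existing)
          case None => Left(sub)
        }
    }

  def main(args: Array[String]): Unit = {
    val s = Subquery(1, "canonical-plan")
    assert(decide(s) == Left(s))  // first sighting: execute it
    assert(decide(s) == Right(s)) // same object seen again: wrap as reused
  }
}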