diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index fcf65659c24f..2d7cccbd7016 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -106,6 +106,7 @@ object TreePattern extends Enumeration {
   val AGGREGATE: Value = Value
   val AS_OF_JOIN: Value = Value
   val COMMAND: Value = Value
+  val COMMAND_RESULT: Value = Value
   val CTE: Value = Value
   val DF_DROP_COLUMNS: Value = Value
   val DISTINCT_LIKE: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 189be1d6a30d..0bf90643605e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -271,13 +271,7 @@ class Dataset[T] private[sql](
   private[sql] def getRows(
       numRows: Int,
       truncate: Int): Seq[Seq[String]] = {
-    val newDf = logicalPlan match {
-      case c: CommandResult =>
-        // Convert to `LocalRelation` and let `ConvertToLocalRelation` do the casting locally to
-        // avoid triggering a job
-        Dataset.ofRows(sparkSession, LocalRelation(c.output, c.rows))
-      case _ => toDF()
-    }
+    val newDf = toDF()
     val castCols = newDf.logicalPlan.output.map { col =>
       Column(ToPrettyString(col))
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ConvertCommandResultToLocalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ConvertCommandResultToLocalRelation.scala
new file mode 100644
index 000000000000..a6511d9bed07
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ConvertCommandResultToLocalRelation.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, IntegerLiteral, InterpretedMutableProjection, Predicate, Unevaluable}
+import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, Filter, Limit, LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND_RESULT
+
+/**
+ * Converts local operations (i.e. ones that don't require data exchange) on `CommandResult`
+ * to `LocalRelation`.
+ */
+object ConvertCommandResultToLocalRelation extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(COMMAND_RESULT)) {
+    case Project(projectList, CommandResult(output, _, _, rows))
+        if !projectList.exists(hasUnevaluableExpr) =>
+      val projection = new InterpretedMutableProjection(projectList, output)
+      projection.initialize(0)
+      LocalRelation(projectList.map(_.toAttribute), rows.map(projection(_).copy()))
+
+    case Limit(IntegerLiteral(limit), CommandResult(output, _, _, rows)) =>
+      LocalRelation(output, rows.take(limit))
+
+    case Filter(condition, CommandResult(output, _, _, rows))
+        if !hasUnevaluableExpr(condition) =>
+      val predicate = Predicate.create(condition, output)
+      predicate.initialize(0)
+      LocalRelation(output, rows.filter(row => predicate.eval(row)))
+  }
+
+  private def hasUnevaluableExpr(expr: Expression): Boolean = {
+    expr.exists(e => e.isInstanceOf[Unevaluable] && !e.isInstanceOf[AttributeReference])
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CommandResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CommandResult.scala
index 2ef342227833..cd3e8c7de296 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CommandResult.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CommandResult.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND_RESULT, TreePattern}
 import org.apache.spark.sql.execution.SparkPlan
 
 /**
@@ -40,4 +41,6 @@ case class CommandResult(
 
   override def computeStats(): Statistics =
     Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * rows.length)
+
+  override val nodePatterns: Seq[TreePattern] = Seq(COMMAND_RESULT)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 70a35ea91153..2c59d43acbd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.ExperimentalMethods
+import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -66,7 +67,14 @@ class SparkOptimizer(
       CleanupDynamicPruningFilters,
       // cleanup the unnecessary TrueLiteral predicates
       BooleanSimplification,
-      PruneFilters)) ++
+      PruneFilters) :+
+    Batch("Convert CommandResult to LocalRelation", fixedPoint,
+      ConvertCommandResultToLocalRelation,
+      ConvertToLocalRelation,
+      PropagateEmptyRelation,
+      // PropagateEmptyRelation can change the nullability of an attribute from nullable to
+      // non-nullable when an empty relation child of a Union is removed
+      UpdateAttributeNullability)) ++
     postHocOptimizationBatches :+
     Batch("Extract Python UDFs", Once,
       ExtractPythonUDFFromJoinCondition,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index fe295b0cfa26..387b9c713c5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2717,6 +2717,26 @@ class DatasetSuite extends QueryTest
     checkDataset(ds.map(t => t),
       WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar", "zoo")))
   }
+
+  test("SPARK-47270: isEmpty does not trigger job execution on CommandResults") {
+    withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
+      withTable("t1", "t2") {
+        sql("create table t1(c1 int) using parquet")
+        sql("create table t2(c2 int) using parquet")
+
+        @volatile var jobCounter = 0
+        val listener = new SparkListener {
+          override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+            jobCounter += 1
+          }
+        }
+        withListener(spark.sparkContext, listener) { _ =>
+          sql("show tables").isEmpty
+        }
+        assert(jobCounter === 0)
+      }
+    }
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertCommandResultToLocalRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertCommandResultToLocalRelationSuite.scala
new file mode 100644
index 000000000000..c8cff23729c4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertCommandResultToLocalRelationSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class ConvertCommandResultToLocalRelationSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("ConvertCommandResultToLocalRelation", FixedPoint(100),
+        ConvertCommandResultToLocalRelation,
+        ConvertToLocalRelation) :: Nil
+  }
+
+  test("Project on CommandResult should be turned into a single LocalRelation") {
+    val testCommandResult = CommandResult(
+      Seq($"a".int, $"b".int),
+      null,
+      null,
+      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)
+
+    val correctAnswer = LocalRelation(
+      LocalRelation($"a1".int, $"b1".int).output,
+      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)
+
+    val projectOnLocal = testCommandResult.select(
+      UnresolvedAttribute("a").as("a1"),
+      (UnresolvedAttribute("b") + 1).as("b1"))
+
+    val optimized = Optimize.execute(projectOnLocal.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Filter on CommandResult should be turned into a single LocalRelation") {
+    val testCommandResult = CommandResult(
+      Seq($"a".int, $"b".int),
+      null,
+      null,
+      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)
+
+    val correctAnswer = LocalRelation(
+      LocalRelation($"a1".int, $"b1".int).output,
+      InternalRow(1, 3) :: Nil)
+
+    val filterAndProjectOnLocal = testCommandResult
+      .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1"))
+      .where(LessThan(UnresolvedAttribute("b1"), Literal.create(6)))
+
+    val optimized = Optimize.execute(filterAndProjectOnLocal.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-27798: Expression reusing output shouldn't override values in local relation") {
+    val testCommandResult = CommandResult(
+      Seq($"a".int),
+      null,
+      null,
+      InternalRow(1) :: InternalRow(2) :: Nil)
+
+    val correctAnswer = LocalRelation(
+      LocalRelation($"a".struct($"a1".int)).output,
+      InternalRow(InternalRow(1)) :: InternalRow(InternalRow(2)) :: Nil)
+
+    val projected = testCommandResult.select(ExprReuseOutput(UnresolvedAttribute("a")).as("a"))
+    val optimized = Optimize.execute(projected.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
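
Reviewer note (not part of the patch): below is a minimal sketch of the behavior this change enables, assuming the patch is applied and a SparkSession is bound to `spark` in a REPL or test. It mirrors the new DatasetSuite test: local operations on a CommandResult now collapse to a LocalRelation during optimization, so they are answered from the already-materialized command rows without launching a Spark job.

    // Sketch only; `spark` is an assumed SparkSession.
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val df = spark.sql("SHOW TABLES")

    // limit(1) on the CommandResult is rewritten by
    // ConvertCommandResultToLocalRelation into a LocalRelation holding the
    // command's rows, so the optimized plan contains no CommandResult scan.
    assert(df.limit(1).queryExecution.optimizedPlan
      .collect { case l: LocalRelation => l }.nonEmpty)

    // As the new test asserts, this is evaluated locally; no job is triggered.
    df.isEmpty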