diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9519a56c2817..231a8a1f190a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -214,7 +214,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
     // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts) :+
-    Batch("RewriteSubquery", Once,
+    // `CollapseProject` cannot collapse all Projects in one pass, so we need `fixedPoint` here.
+    Batch("RewriteSubquery", fixedPoint,
       RewritePredicateSubquery,
       ColumnPruning,
       CollapseProject,
@@ -724,20 +725,17 @@ object ColumnPruning extends Rule[LogicalPlan] {
 /**
  * Combines two [[Project]] operators into one and perform alias substitution,
  * merging the expressions into one single expression for the following cases.
- * 1. When two [[Project]] operators are adjacent.
+ * 1. When two [[Project]] operators are adjacent, if the number of common expressions in the
+ *    combined [[Project]] does not exceed `spark.sql.optimizer.maxCommonExprsInCollapseProject`.
  * 2. When two [[Project]] operators have LocalLimit/Sample/Repartition operator between them
  *    and the upper project consists of the same number of columns which is equal or aliasing.
  *    `GlobalLimit(LocalLimit)` pattern is also considered.
  */
 object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case p1 @ Project(_, p2: Project) =>
-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
-        p1
-      } else {
-        p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
-      }
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case p @ Project(_, _: Project) =>
+      collapseProjects(p)
     case p @ Project(_, agg: Aggregate) =>
       if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
         p
@@ -758,6 +756,42 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
       s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
   }
 
+  private def collapseProjects(plan: LogicalPlan): LogicalPlan = plan match {
+    case p1 @ Project(_, p2: Project) =>
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
+          moreThanMaxAllowedCommonOutput(p1.projectList, p2.projectList)) {
+        p1
+      } else {
+        collapseProjects(
+          p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)))
+      }
+    case _ => plan
+  }
+
+  private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
+    AttributeMap(projectList.collect {
+      case a: Alias => a.toAttribute -> a
+    })
+  }
+
+  // Whether the number of times the most-referenced common output of the lower operator is
+  // used in the upper operator exceeds the maximum allowed.
+  private def moreThanMaxAllowedCommonOutput(
+      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
+    val aliases = collectAliases(lower)
+    val exprMap = mutable.HashMap.empty[Attribute, Int]
+
+    upper.foreach(_.collect {
+      case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
+    })
+
+    if (exprMap.nonEmpty) {
+      exprMap.maxBy(_._2)._2 > SQLConf.get.maxCommonExprsInCollapseProject
+    } else {
+      false
+    }
+  }
+
   private def haveCommonNonDeterministicOutput(
       upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
     val aliases = getAliasMap(lower)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 2880e87ab156..36ccea341a8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.planning
 
+import scala.collection.mutable
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
@@ -124,14 +126,32 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       }.exists(!_.deterministic))
   }
 
+  private def moreThanMaxAllowedCommonOutput(
+      expr: Seq[NamedExpression],
+      aliases: AttributeMap[Expression]): Boolean = {
+    val exprMap = mutable.HashMap.empty[Attribute, Int]
+
+    expr.foreach(_.collect {
+      case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
+    })
+
+    if (exprMap.nonEmpty) {
+      exprMap.maxBy(_._2)._2 > SQLConf.get.maxCommonExprsInCollapseProject
+    } else {
+      false
+    }
+  }
+
   private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
     plan match {
       case Project(fields, child) =>
         collectProjectsAndFilters(child) match {
           case Some((_, filters, other, aliases)) =>
             // Follow CollapseProject and only keep going if the collected Projects
-            // do not have common non-deterministic expressions.
-            if (!hasCommonNonDeterministic(fields, aliases)) {
+            // do not have common non-deterministic expressions, and do not have more than the
+            // maximum allowed number of common outputs.
+            if (!hasCommonNonDeterministic(fields, aliases) &&
+                !moreThanMaxAllowedCommonOutput(fields, aliases)) {
               val substitutedFields =
                 fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
               Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8825f4f96378..752fdcd50688 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1963,6 +1963,27 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT =
+    buildConf("spark.sql.optimizer.maxCommonExprsInCollapseProject")
+      .doc("An integer that indicates the maximum allowed number of common input " +
+        "expressions from a lower Project when it is collapsed into an upper Project " +
+        "by the optimizer rule `CollapseProject`. Normally `CollapseProject` collapses " +
+        "adjacent Projects and merges their expressions, but in some edge cases " +
+        "expensive expressions might be duplicated many times in the merged Project. " +
+        "This config sets a maximum: once an expression would be duplicated more than " +
+        "this many times when merging two Projects, Spark SQL skips the merging. Note " +
+        "that normally, in whole-stage codegen, the Project operator de-duplicates " +
+        "expressions internally, but in edge cases Spark cannot use whole-stage codegen " +
+        "and falls back to interpreted mode. In such cases, users can use this config " +
+        "to avoid duplicated expressions. Note that even if users exclude the " +
+        "`CollapseProject` rule via `spark.sql.optimizer.excludedRules`, Spark still " +
+        "collapses projections during physical planning, so this config is also " +
+        "effective on the projections collapsed at the physical planning phase.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ > 0, "The value of maxCommonExprsInCollapseProject must be larger than zero.")
+      .createWithDefault(Int.MaxValue)
+
   val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS =
     buildConf("spark.sql.decimalOperations.allowPrecisionLoss")
       .internal()
@@ -3405,6 +3426,8 @@ class SQLConf extends Serializable with Logging {
 
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
+  def maxCommonExprsInCollapseProject: Int = getConf(MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT)
+
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
 
   def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 42bcd13ee378..1a57731700dd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Rand}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.MetadataBuilder
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{MetadataBuilder, StructType}
 
 class CollapseProjectSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -170,4 +171,59 @@ class CollapseProjectSuite extends PlanTest {
     val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
     comparePlans(optimized, expected)
   }
+
+  test("SPARK-32945: avoid collapsing projects if reaching max allowed common exprs") {
+    val options = Map.empty[String, String]
+    val schema = StructType.fromDDL("a int, b int, c string, d long")
+
+    Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
+      withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
+        // If the Projects were collapsed, `JsonToStructs` would be duplicated per struct field.
+        val relation = LocalRelation('json.string)
+        val query1 = relation.select(
+          JsonToStructs(schema, options, 'json).as("struct"))
+          .select(
+            GetStructField('struct, 0).as("a"),
+            GetStructField('struct, 1).as("b"),
+            GetStructField('struct, 2).as("c"),
+            GetStructField('struct, 3).as("d")).analyze
+        val optimized1 = Optimize.execute(query1)
+
+        val query2 = relation
+          .select('json, JsonToStructs(schema, options, 'json).as("struct"))
+          .select('json, 'struct, GetStructField('struct, 0).as("a"))
+          .select('json, 'struct, 'a, GetStructField('struct, 1).as("b"))
+          .select('json, 'struct, 'a, 'b, GetStructField('struct, 2).as("c"))
+          .analyze
+        val optimized2 = Optimize.execute(query2)
+
+        if (maxCommonExprs.toInt < 4) {
+          val expected1 = query1
+          comparePlans(optimized1, expected1)
+
+          val expected2 = relation
+            .select('json, JsonToStructs(schema, options, 'json).as("struct"))
+            .select('json, 'struct,
+              GetStructField('struct, 0).as("a"),
+              GetStructField('struct, 1).as("b"),
+              GetStructField('struct, 2).as("c"))
+            .analyze
+          comparePlans(optimized2, expected2)
+        } else {
+          val expected1 = relation.select(
+            GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
+            GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
+            GetStructField(JsonToStructs(schema, options, 'json), 2).as("c"),
+            GetStructField(JsonToStructs(schema, options, 'json), 3).as("d")).analyze
+          comparePlans(optimized1, expected1)
+
+          val expected2 = relation.select('json, JsonToStructs(schema, options, 'json).as("struct"),
+            GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
+            GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
+            GetStructField(JsonToStructs(schema, options, 'json), 2).as("c")).analyze
+          comparePlans(optimized2, expected2)
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 321f4966178d..cccc44094c98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -32,12 +32,13 @@ import org.scalatest.matchers.should.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.Uuid
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -2567,6 +2568,51 @@ class DataFrameSuite extends QueryTest
     val df = l.join(r, $"col2" === $"col4", "LeftOuter")
     checkAnswer(df, Row("2", "2"))
   }
+
+  test("SPARK-32945: Avoid collapsing projects if reaching max allowed common exprs") {
+    val options = Map.empty[String, String]
+    val schema = StructType.fromDDL("a int, b int, c long, d string")
+
+    withTable("test_table") {
+      val jsonDf = Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""").toDF("json")
+      jsonDf.write.saveAsTable("test_table")
+
+      Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
+        withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
+
+          val jsonDf = spark.read.table("test_table")
+          val jsonStruct = UnresolvedAttribute("struct")
+          val df = jsonDf
+            .select(from_json('json, schema, options).as("struct"))
+            .select(
+              Column(GetStructField(jsonStruct, 0)).as("a"),
+              Column(GetStructField(jsonStruct, 1)).as("b"),
+              Column(GetStructField(jsonStruct, 2)).as("c"),
+              Column(GetStructField(jsonStruct, 3)).as("d"))
+
+          val numProjects = df.queryExecution.executedPlan.collect {
+            case p: ProjectExec => p
+          }.size
+
+          val numFromJson = df.queryExecution.executedPlan.collect {
+            case p: ProjectExec => p.projectList.flatMap(_.collect {
+              case j: JsonToStructs => j
+            })
+          }.flatten.size
+
+          if (maxCommonExprs.toInt < 4) {
+            assert(numProjects == 2)
+            assert(numFromJson == 1)
+          } else {
+            assert(numProjects == 1)
+            assert(numFromJson == 4)
+          }
+
+          checkAnswer(df, Row(1, 2, 123L, "test"))
+        }
+      }
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
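
Reviewer note (not part of the patch): the sketch below is a minimal, self-contained way to try the new config from a user's perspective. It assumes a Spark build that includes this change; the session setup, schema, and column names are illustrative only. With the limit set to 2, collapsing the two Projects would duplicate `from_json` four times, so the optimizer is expected to keep them separate, mirroring the new DataFrameSuite test.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, from_json}
import org.apache.spark.sql.types.StructType

object MaxCommonExprsInCollapseProjectDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("maxCommonExprsInCollapseProject demo")
      // Allow a common expression to be duplicated at most twice when collapsing Projects.
      .config("spark.sql.optimizer.maxCommonExprsInCollapseProject", "2")
      .getOrCreate()
    import spark.implicits._

    val schema = StructType.fromDDL("a INT, b INT, c LONG, d STRING")
    val df = Seq("""{"a":1, "b":2, "c":123, "d":"test"}""").toDF("json")
      // Collapsing these two Projects would repeat `from_json` once per extracted field.
      .select(from_json($"json", schema).as("struct"))
      .select(col("struct.a"), col("struct.b"), col("struct.c"), col("struct.d"))

    // With the limit of 2, the optimized and physical plans should keep two Project
    // operators, so the JSON string is parsed only once per row.
    df.explain(true)
    df.show()

    spark.stop()
  }
}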