diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 105623c767d6..0fa8e89bebbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -631,19 +631,26 @@ object ColumnPruning extends Rule[LogicalPlan] {
 object CollapseProject extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, p2: Project) =>
-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
-        p1
-      } else {
-        p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
-      }
+      maybeGetCollapsedAndCleanedProjectList(p1, p2.projectList)
+        .map(cleanedProjectList => p2.copy(projectList = cleanedProjectList))
+        .getOrElse(p1)
     case p @ Project(_, agg: Aggregate) =>
-      if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
-        p
-      } else {
-        agg.copy(aggregateExpressions = buildCleanedProjectList(
-          p.projectList, agg.aggregateExpressions))
+      maybeGetCollapsedAndCleanedProjectList(p, agg.aggregateExpressions)
+        .map(cleanedProjectList => agg.copy(aggregateExpressions = cleanedProjectList))
+        .getOrElse(p)
+  }
+
+  private def maybeGetCollapsedAndCleanedProjectList(
+      upper: Project,
+      lowerProjectList: Seq[NamedExpression]): Option[Seq[NamedExpression]] = {
+    if (!haveCommonNonDeterministicOutput(upper.projectList, lowerProjectList)) {
+      val cleanedProjectList = buildCleanedProjectList(upper.projectList, lowerProjectList)
+      if (isNumberOfLeafExpressionsBelowLimit(cleanedProjectList)) {
+        return Some(cleanedProjectList)
       }
+    }
+    None
   }
 
   private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
@@ -684,6 +691,18 @@ object CollapseProject extends Rule[LogicalPlan] {
       CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
     }
   }
+
+  private def isNumberOfLeafExpressionsBelowLimit(projectList: Seq[NamedExpression]): Boolean = {
+    SQLConf.get.optimizerMaxNumOfLeafExpressionsInCollapsedProject < 0 ||
+      numberOfLeafExpressions(projectList) <=
+        SQLConf.get.optimizerMaxNumOfLeafExpressionsInCollapsedProject
+  }
+
+  private def numberOfLeafExpressions(projectList: Seq[Expression]): Long = {
+    projectList
+      .map(expr => if (expr.children.nonEmpty) numberOfLeafExpressions(expr.children) else 1L)
+      .sum
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index edc1a488150c..4c2336737dda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -143,6 +143,15 @@ object SQLConf {
     .intConf
     .createWithDefault(100)
 
+  val OPTIMIZER_MAX_NUM_OF_LEAF_EXPRESSIONS_IN_COLLAPSED_PROJECT =
+    buildConf("spark.sql.optimizer.maxNumOfLeafExpressionsInCollapsedProject")
+      .internal()
+      .doc("Sets the maximum number of leaf expressions that a project is allowed to " +
+        "have after collapsing. If the collapsed project would have more leaf expressions " +
+        "than this number, the optimizer keeps the projects separate. Set to -1 to disable.")
+      .longConf
+      .createWithDefault(10000)
+
   val OPTIMIZER_INSET_CONVERSION_THRESHOLD =
     buildConf("spark.sql.optimizer.inSetConversionThreshold")
       .internal()
@@ -1477,6 +1486,9 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS)
 
+  def optimizerMaxNumOfLeafExpressionsInCollapsedProject: Long =
+    getConf(OPTIMIZER_MAX_NUM_OF_LEAF_EXPRESSIONS_IN_COLLAPSED_PROJECT)
+
   def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
 
   def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index e7a5bcee420f..10381c5f08e4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -138,4 +138,16 @@ class CollapseProjectSuite extends PlanTest {
     assert(projects.size === 1)
     assert(hasMetadata(optimized))
   }
+
+  test("do not collapse if the number of leaf expressions would be too big") {
+    var query: LogicalPlan = testRelation
+    for (_ <- 1 to 13) {
+      // After n iterations the collapsed project would hold 2^(n+1) leaf expressions,
+      // so after 13 iterations it would exceed the default limit of 10000.
+      query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
+    }
+
+    val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
+    assert(projects.size === 2) // everything is collapsed except the last project
+  }
 }
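For context, a minimal standalone sketch of the blow-up the new limit guards against, not part of the patch: when CollapseProject merges two adjacent projects it inlines the lower project's aliases into every reference in the upper one, so a chain of `SELECT a+b AS a, a-b AS b` projects doubles the leaf count at each level, reaching 2^(n+1) leaves after n levels. `Leaf`, `Node`, and `LeafBlowUpDemo` below are hypothetical stand-ins for Catalyst attributes and operators; only the counting logic mirrors `numberOfLeafExpressions` from the patch.

```scala
// Standalone sketch (plain Scala, no Spark dependency) of the exponential
// leaf-expression growth that the new limit in CollapseProject guards against.
object LeafBlowUpDemo extends App {
  sealed trait Expr { def children: Seq[Expr] }
  case class Leaf(name: String) extends Expr { val children: Seq[Expr] = Nil }
  case class Node(children: Expr*) extends Expr

  // Mirrors the counting logic added in the patch: leaves are counted once
  // per occurrence, which is exactly how inlined aliases duplicate work.
  def numberOfLeafExpressions(exprs: Seq[Expr]): Long =
    exprs.map(e => if (e.children.nonEmpty) numberOfLeafExpressions(e.children) else 1L).sum

  var a: Expr = Leaf("a")
  var b: Expr = Leaf("b")
  for (n <- 1 to 13) {
    // Collapsing SELECT a+b AS a, a-b AS b inlines the previous a and b into
    // both new expressions, doubling the total leaf count each iteration.
    val (prevA, prevB) = (a, b)
    a = Node(prevA, prevB) // models a + b
    b = Node(prevA, prevB) // models a - b
    println(s"after $n collapses: ${numberOfLeafExpressions(Seq(a, b))} leaves") // 2^(n+1)
  }
}
```

With the default limit of 10000, the 13th collapse would produce 2^14 = 16384 leaf expressions and is therefore skipped, which is what the new test asserts; setting `spark.sql.optimizer.maxNumOfLeafExpressionsInCollapsedProject` to -1 restores the old always-collapse behavior.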