diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1b141572cc7f..236282f07f4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -697,16 +697,17 @@ object ColumnPruning extends Rule[LogicalPlan] {
  * `GlobalLimit(LocalLimit)` pattern is also considered.
  */
 object CollapseProject extends Rule[LogicalPlan] {
-
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, p2: Project) =>
-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)
+          || hasTooManyExprs(p2.projectList)) {
         p1
       } else {
         p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
       }
     case p @ Project(_, agg: Aggregate) =>
-      if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
+      if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)
+          || hasTooManyExprs(agg.aggregateExpressions)) {
         p
       } else {
         agg.copy(aggregateExpressions = buildCleanedProjectList(
@@ -725,6 +726,14 @@ object CollapseProject extends Rule[LogicalPlan] {
       s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
   }
 
+  private def hasTooManyExprs(exprs: Seq[Expression]): Boolean = {
+    if (SQLConf.get.optimizerCollapseProjectExpressionThreshold == -1) false else {
+      var numExprs = 0
+      exprs.foreach { _.foreach { _ => numExprs += 1 } }
+      numExprs > SQLConf.get.optimizerCollapseProjectExpressionThreshold
+    }
+  }
+
   private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
     AttributeMap(projectList.collect {
       case a: Alias => a.toAttribute -> a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 415ce4678811..7a841c8e1aee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 
 trait OperationHelper {
   type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
@@ -122,6 +123,14 @@ object ScanOperation extends OperationHelper with PredicateHelper {
     }.exists(!_.deterministic))
   }
 
+  private def hasTooManyExprs(exprs: Seq[Expression]): Boolean = {
+    if (SQLConf.get.optimizerCollapseProjectExpressionThreshold == -1) false else {
+      var numExprs = 0
+      exprs.foreach { _.foreach { _ => numExprs += 1 } }
+      numExprs > SQLConf.get.optimizerCollapseProjectExpressionThreshold
+    }
+  }
+
   private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
     plan match {
       case Project(fields, child) =>
@@ -132,7 +141,9 @@ object ScanOperation extends OperationHelper with PredicateHelper {
           if (!hasCommonNonDeterministic(fields, aliases)) {
             val substitutedFields =
               fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
-            Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
+            if (hasTooManyExprs(substitutedFields)) None else {
+              Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
+            }
           } else {
             None
           }
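
Why a threshold is needed at all: collapsing substitutes the lower project's aliases into the upper project list, and every reference to an alias copies the whole aliased subtree. Below is a minimal standalone sketch of the resulting blow-up, using a toy Expr type invented for this note rather than Spark's Expression class:

sealed trait Expr { def numNodes: Int }
case class Leaf(name: String) extends Expr { val numNodes = 1 }
case class BinOp(op: String, l: Expr, r: Expr) extends Expr {
  val numNodes = 1 + l.numNodes + r.numNodes
}

// Ten collapse steps over SELECT a + b AS a, a - b AS b: each step inlines the
// previous definitions of `a` and `b`, roughly doubling the node count.
var a: Expr = Leaf("a")
var b: Expr = Leaf("b")
for (_ <- 1 to 10) {
  val (na, nb) = (BinOp("+", a, b), BinOp("-", a, b))
  a = na
  b = nb
}
println(a.numNodes + b.numNodes) // 4094, far above the default threshold of 1000

This is the growth that hasTooManyExprs bounds: it walks each expression tree with foreach and compares the total node count against the configured threshold.
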
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9be0497e4660..e918f777c35b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -245,6 +245,15 @@ object SQLConf {
     .stringConf
     .createOptional
 
+  val OPTIMIZER_COLLAPSE_PROJECT_EXPRESSION_THRESHOLD =
+    buildConf("spark.sql.optimizer.collapseProjectExpressionThreshold")
+      .internal()
+      .doc("Sets a threshold for the number of expression nodes in a project when " +
+        "collapsing adjacent projects; if a project contains more expression nodes than " +
+        "the threshold, the projects are not collapsed. Set to -1 to disable.")
+      .intConf
+      .createWithDefault(1000)
+
   val DYNAMIC_PARTITION_PRUNING_ENABLED =
     buildConf("spark.sql.optimizer.dynamicPartitionPruning.enabled")
       .doc("When true, we will generate predicate for partition column when it's used as join key")
@@ -2780,6 +2789,9 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerPlanChangeBatches: Option[String] = getConf(OPTIMIZER_PLAN_CHANGE_LOG_BATCHES)
 
+  def optimizerCollapseProjectExpressionThreshold: Int =
+    getConf(OPTIMIZER_COLLAPSE_PROJECT_EXPRESSION_THRESHOLD)
+
   def dynamicPartitionPruningEnabled: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_ENABLED)
 
   def dynamicPartitionPruningUseStats: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_USE_STATS)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 42bcd13ee378..e49539573b46 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -121,6 +121,17 @@ class CollapseProjectSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("do not collapse project if number of leaf expressions would be too big") {
+    var query: LogicalPlan = testRelation
+    for (_ <- 1 to 10) {
+      // After n iterations the number of leaf expressions will be 2^{n+1},
+      // so after 10 iterations we end up with more than 1000 leaf expressions.
+      query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
+    }
+    val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
+    assert(projects.size === 2) // the ten stacked projects should collapse to only two
+  }
+
   test("preserve top-level alias metadata while collapsing projects") {
     def hasMetadata(logicalPlan: LogicalPlan): Boolean = {
       logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key"))
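
For completeness, a hedged usage sketch for the new knob, assuming a live SparkSession bound to a variable named spark (the key and its default of 1000 come from the SQLConf change above; the entry is marked internal(), so it is deliberately left out of the public configuration docs):

// Raise the limit if wide, deeply aliased projections should still be collapsed.
spark.conf.set("spark.sql.optimizer.collapseProjectExpressionThreshold", "10000")

// Or set -1 to disable the guard and restore the previous always-collapse behavior.
spark.conf.set("spark.sql.optimizer.collapseProjectExpressionThreshold", "-1")
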