Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -631,19 +631,26 @@ object ColumnPruning extends Rule[LogicalPlan] {
object CollapseProject extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
p1
} else {
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
}
case p1@Project(_, p2: Project) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we need to change this line? We can keep this line as is.

maybeGetCollapsedAndCleanedProjectList(p1, p2.projectList)
.map(cleanedProjectList => p2.copy(projectList = cleanedProjectList))
.getOrElse(p1)
case p @ Project(_, agg: Aggregate) =>
if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
p
} else {
agg.copy(aggregateExpressions = buildCleanedProjectList(
p.projectList, agg.aggregateExpressions))
maybeGetCollapsedAndCleanedProjectList(p, agg.aggregateExpressions)
.map(cleanedProjectList => agg.copy(aggregateExpressions = cleanedProjectList))
.getOrElse(p)
}

/**
 * Tries to build the project list that results from collapsing `upper` into the operator
 * directly below it.
 *
 * @param upper the upper [[Project]] whose project list is merged into the lower one
 * @param lowerProjectList the named expressions produced by the operator below `upper`
 * @return `Some(cleanedProjectList)` when the two lists have no common non-deterministic
 *         output and the cleaned list passes the leaf-expression limit check;
 *         `None` otherwise, in which case the caller is expected to leave the plan as is
 */
private def maybeGetCollapsedAndCleanedProjectList(
    upper: Project,
    lowerProjectList: Seq[NamedExpression]): Option[Seq[NamedExpression]] = {
  if (haveCommonNonDeterministicOutput(upper.projectList, lowerProjectList)) {
    None
  } else {
    // Build the merged list first, then keep it only if it stays within the configured limit.
    Some(buildCleanedProjectList(upper.projectList, lowerProjectList))
      .filter(isNumberOfLeafExpressionsBelowLimit)
  }
}

private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
Expand Down Expand Up @@ -684,6 +691,18 @@ object CollapseProject extends Rule[LogicalPlan] {
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
}
}

/**
 * Checks whether `projectList` is small enough to be used as a collapsed project list.
 *
 * A negative configured limit disables the check, so every list is accepted; otherwise the
 * number of leaf expressions in `projectList` must not exceed the limit.
 */
private def isNumberOfLeafExpressionsBelowLimit(projectList: Seq[NamedExpression]): Boolean = {
  val maxLeafExpressions = SQLConf.get.optimizerMaxNumOfLeafExpressionsInCollapsedProject
  // Short-circuit keeps the (potentially expensive) leaf count from running when disabled.
  maxLeafExpressions < 0 || numberOfLeafExpressions(projectList) <= maxLeafExpressions
}

/**
 * Counts the leaf expressions (expressions without children) reachable from the given
 * expressions by recursively descending into `children`.
 */
private def numberOfLeafExpressions(projectList: Seq[Expression]): Long = {
  projectList.foldLeft(0L) { (total, expr) =>
    // An expression without children is itself a leaf and contributes exactly one.
    val leaves = if (expr.children.isEmpty) 1L else numberOfLeafExpressions(expr.children)
    total + leaves
  }
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ object SQLConf {
.intConf
.createWithDefault(100)

// Internal optimizer setting: upper bound on the number of leaf expressions a project may
// contain after collapsing. Stored as a long with a default of 10000; per the doc string
// below, -1 disables the limit.
val OPTIMIZER_MAX_NUM_OF_LEAF_EXPRESSIONS_IN_COLLAPSED_PROJECT =
buildConf("spark.sql.optimizer.maxNumOfLeafExpressionsInCollapsedProject")
.internal()
.doc("Sets the maximum number of leaf expressions that a project is allowed to " +
"have after collapsing. If the collapsed project would have more leaf expressions " +
"than this number then the optimizer won't collapse. Set to -1 to disable.")
.longConf
.createWithDefault(10000)

val OPTIMIZER_INSET_CONVERSION_THRESHOLD =
buildConf("spark.sql.optimizer.inSetConversionThreshold")
.internal()
Expand Down Expand Up @@ -1477,6 +1486,9 @@ class SQLConf extends Serializable with Logging {

/** Returns the value of the OPTIMIZER_MAX_ITERATIONS configuration entry. */
def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS)

/**
 * Returns the value of the OPTIMIZER_MAX_NUM_OF_LEAF_EXPRESSIONS_IN_COLLAPSED_PROJECT
 * configuration entry.
 */
def optimizerMaxNumOfLeafExpressionsInCollapsedProject: Long =
getConf(OPTIMIZER_MAX_NUM_OF_LEAF_EXPRESSIONS_IN_COLLAPSED_PROJECT)

/** Returns the value of the OPTIMIZER_INSET_CONVERSION_THRESHOLD configuration entry. */
def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)

/** Returns the value of the STATE_STORE_PROVIDER_CLASS configuration entry. */
def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,16 @@ class CollapseProjectSuite extends PlanTest {
assert(projects.size === 1)
assert(hasMetadata(optimized))
}

test("do not collapse if number of leaf expressions would be too big") {
  // Every stacked projection references both 'a and 'b twice, so a fully collapsed plan
  // doubles its leaf count per level: after n projections it holds 2^(n+1) leaf expressions.
  // 12 projections collapse to 8192 leaves, while 13 would need 16384, which exceeds the
  // default limit of 10000.
  val query = (1 to 13).foldLeft(testRelation: LogicalPlan) { (plan, _) =>
    plan.select(('a + 'b).as('a), ('a - 'b).as('b))
  }

  val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
  assert(projects.size === 2) // everything should be collapsed except the last one
}
}