Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
}

override protected val blacklistedOnceBatches: Set[String] =
Set("Pullup Correlated Expressions",
"Extract Python UDFs"
)
Set("Extract Python UDFs")

/** Builds a FixedPoint from `SQLConf.get.optimizerMaxIterations`; re-evaluated on every call. */
protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,28 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}

private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = {
/**
* This function is used as an aid to enforce idempotency of the PullupCorrelatedPredicates rule.
* In the first call to rewriteSubQueries, all the outer references from the subplan are
* pulled up and the join predicates are recorded as children of the enclosing subquery
* expression. A subsequent call to rewriteSubQueries simply re-records those `children`,
* which already contain the correlated predicates pulled up by the previous call, in the
* enclosing subquery expression.
*/
def getJoinCondition(newCond: Seq[Expression], oldCond: Seq[Expression]): Seq[Expression] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it true that this function is meant to work around cases where newCond is empty while the outer references are not? Maybe we should add a comment here, since it might be tricky to understand...

if (newCond.isEmpty) oldCond else newCond
}

plan transformExpressions {
case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
ScalarSubquery(newPlan, newCond, exprId)
ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
case Exists(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
Exists(newPlan, newCond, exprId)
case ListQuery(sub, _, exprId, childOutputs) =>
Exists(newPlan, getJoinCondition(newCond, children), exprId)
case ListQuery(sub, children, exprId, childOutputs) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
ListQuery(newPlan, newCond, exprId, childOutputs)
ListQuery(newPlan, getJoinCondition(newCond, children), exprId, childOutputs)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class PullupCorrelatedPredicatesSuite extends PlanTest {
Expand All @@ -38,17 +38,65 @@ class PullupCorrelatedPredicatesSuite extends PlanTest {
val testRelation2 = LocalRelation('c.int, 'd.double)

test("PullupCorrelatedPredicates should not produce unresolved plan") {
val correlatedSubquery =
val subPlan =
testRelation2
.where('b < 'd)
.select('c)
val outerQuery =
val inSubquery =
testRelation
.where(InSubquery(Seq('a), ListQuery(correlatedSubquery)))
.where(InSubquery(Seq('a), ListQuery(subPlan)))
.select('a).analyze
assert(outerQuery.resolved)
assert(inSubquery.resolved)

val optimized = Optimize.execute(outerQuery)
val optimized = Optimize.execute(inSubquery)
assert(optimized.resolved)
}

test("PullupCorrelatedPredicates in correlated subquery idempotency check") {
  // IN-subquery over testRelation2 whose filter references 'b, a column that
  // testRelation2 does not define, so the predicate is correlated with the outer query.
  val correlatedPlan = testRelation2.where('b < 'd).select('c)
  val analyzedQuery = testRelation
    .where(InSubquery(Seq('a), ListQuery(correlatedPlan)))
    .select('a)
    .analyze
  assert(analyzedQuery.resolved)

  // Feeding the optimizer its own output must leave the plan unchanged.
  val firstPass = Optimize.execute(analyzedQuery)
  val secondPass = Optimize.execute(firstPass)
  comparePlans(firstPass, secondPass)
}

test("PullupCorrelatedPredicates exists correlated subquery idempotency check") {
  // EXISTS-subquery over testRelation2; its filter mixes a predicate on 'b (not a
  // column of testRelation2) with a purely local predicate on 'd.
  val correlatedPlan = testRelation2
    .where('b === 'd && 'd === 1)
    .select(Literal(1))
  val analyzedQuery = testRelation
    .where(Exists(correlatedPlan))
    .select('a)
    .analyze
  assert(analyzedQuery.resolved)

  // Feeding the optimizer its own output must leave the plan unchanged.
  val firstPass = Optimize.execute(analyzedQuery)
  val secondPass = Optimize.execute(firstPass)
  comparePlans(firstPass, secondPass)
}

test("PullupCorrelatedPredicates scalar correlated subquery idempotency check") {
  // Scalar subquery aggregating max('d) over testRelation2; its filter mixes a
  // predicate on 'b (not a column of testRelation2) with a local predicate on 'd.
  val correlatedPlan = testRelation2
    .where('b === 'd && 'd === 1)
    .select(max('d))
  val analyzedQuery = testRelation
    .where(ScalarSubquery(correlatedPlan))
    .select('a)
    .analyze
  assert(analyzedQuery.resolved)

  // Feeding the optimizer its own output must leave the plan unchanged.
  val firstPass = Optimize.execute(analyzedQuery)
  val secondPass = Optimize.execute(firstPass)
  // NOTE(review): the trailing `false` is presumably the checkAnalysis flag — confirm in PlanTest.
  comparePlans(firstPass, secondPass, false)
}
}