Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -417,68 +417,57 @@ object ColumnPruning extends Rule[LogicalPlan] {
object CollapseProject extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p @ Project(projectList1, Project(projectList2, child)) =>
// Create a map of Aliases to their values from the child projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliasMap = AttributeMap(projectList2.collect {
case a: Alias => (a.toAttribute, a)
})

// We only collapse these two Projects if their overlapped expressions are all
// deterministic.
val hasNondeterministic = projectList1.exists(_.collect {
case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
}.exists(!_.deterministic))

if (hasNondeterministic) {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
p1
} else {
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
}
case p @ Project(_, agg: Aggregate) =>
if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
p
} else {
// Substitute any attributes that are produced by the child projection, so that we safely
// eliminate it.
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
// TODO: Fix TransformBase to avoid the cast below.
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]
// collapse 2 projects may introduce unnecessary Aliases, trim them here.
val cleanedProjection = substitutedProjection.map(p =>
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
)
Project(cleanedProjection, child)
agg.copy(aggregateExpressions = buildCleanedProjectList(
p.projectList, agg.aggregateExpressions))
}
}

// TODO Eliminate duplicate code
// This clause is identical to the one above except that the inner operator is an `Aggregate`
// rather than a `Project`.
case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
// Create a map of Aliases to their values from the child projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliasMap = AttributeMap(projectList2.collect {
case a: Alias => (a.toAttribute, a)
})
private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
AttributeMap(projectList.collect {
case a: Alias => a.toAttribute -> a
})
}

// We only collapse these two Projects if their overlapped expressions are all
// deterministic.
val hasNondeterministic = projectList1.exists(_.collect {
case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
}.exists(!_.deterministic))
private def haveCommonNonDeterministicOutput(
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
// Create a map of Aliases to their values from the lower projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliases = collectAliases(lower)

// Collapse upper and lower Projects if and only if their overlapped expressions are all
// deterministic.
upper.exists(_.collect {
case a: Attribute if aliases.contains(a) => aliases(a).child
}.exists(!_.deterministic))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Further refactored your code a little bit:

def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
  AttributeMap(projectList.collect {
    case a: Alias => a.toAttribute -> a
  })
}

def haveCommonNonDeterministicOutput(
    upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
  val aliases = collectAliases(lower)
  upper.exists(_.collect {
    case a: Attribute if aliases.contains(a) => aliases(a).child
  }).exists(!_.deterministic)
}

def buildCleanedProjectList(
    upper: Seq[NamedExpression],
    lower: Seq[NamedExpression]): Seq[NamedExpression] = {
  val aliases = collectAliases(lower)

  val rewrittenUpper = upper.map(_.transform {
    case a: Attribute => aliases.getOrElse(a, a)
  })

  rewrittenUpper.map { p =>
    CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
  }
}

And those inline comments need some rewording as they are now moved to different contexts.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! will do


if (hasNondeterministic) {
p
} else {
// Substitute any attributes that are produced by the child projection, so that we safely
// eliminate it.
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
// TODO: Fix TransformBase to avoid the cast below.
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]
// collapse 2 projects may introduce unnecessary Aliases, trim them here.
val cleanedProjection = substitutedProjection.map(p =>
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
)
agg.copy(aggregateExpressions = cleanedProjection)
}
private def buildCleanedProjectList(
upper: Seq[NamedExpression],
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
// Create a map of Aliases to their values from the lower projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliases = collectAliases(lower)

// Substitute any attributes that are produced by the lower projection, so that we safely
// eliminate it.
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
val rewrittenUpper = upper.map(_.transform {
case a: Attribute => aliases.getOrElse(a, a)
})
// collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
rewrittenUpper.map { p =>
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class CollapseProjectSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) ::
Batch("CollapseProject", Once, CollapseProject) :: Nil
Batch("CollapseProject", Once, CollapseProject) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int)
Expand Down Expand Up @@ -95,4 +95,28 @@ class CollapseProjectSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("collapse project into aggregate") {
val query = testRelation
.groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b)
.select('a_plus_1, ('b + 1).as('b_plus_1))

val optimized = Optimize.execute(query.analyze)

val correctAnswer = testRelation
.groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

comparePlans(optimized, correctAnswer)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to add one more test case that contains common non-deterministic fields.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do it.


test("do not collapse common nondeterministic project and aggregate") {
val query = testRelation
.groupBy('a)('a, Rand(10).as('rand))
.select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

val optimized = Optimize.execute(query.analyze)
val correctAnswer = query.analyze

comparePlans(optimized, correctAnswer)
}
}