Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,40 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}

// If a condition in a branch is previously seen, this branch can be pruned.
// TODO: In fact, if a condition is a sub-condition of the previous one,
// TODO: it can be pruned. This is less strict and can be implemented
// TODO: by decomposing the seen conditions.
private def pruneSeenBranches(branches: Seq[(Expression, Expression)])
: Option[Seq[(Expression, Expression)]] = {
val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) {
case (newBranches, branch) if newBranches.exists(_._1.semanticEquals(branch._1)) =>
newBranches
case (newBranches, branch) => newBranches += branch
}
if (newBranches.length < branches.length) {
Some(newBranches)
} else {
None
}
}

// If the outputs of two adjacent branches are the same, two branches can be combined.
private def combineAdjacentBranches(branches: Seq[(Expression, Expression)])
: Option[Seq[(Expression, Expression)]] = {
val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) {
case (newBranches, branch)
if newBranches.nonEmpty && newBranches.last._2.semanticEquals(branch._2) =>
newBranches.init += ((Or(newBranches.last._1, branch._1), newBranches.last._2))
case (newBranches, branch) => newBranches += branch
}
if (newBranches.length < branches.length) {
Some(newBranches)
} else {
None
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
Expand Down Expand Up @@ -416,6 +450,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
// these branches can be pruned away
val (h, t) = branches.span(_._1 != TrueLiteral)
CaseWhen( h :+ t.head, None)

case e @ CaseWhen(branches, _) if pruneSeenBranches(branches).nonEmpty =>
e.copy(branches = pruneSeenBranches(branches).get)

case e @ CaseWhen(branches, _) if combineAdjacentBranches(branches).nonEmpty =>
e.copy(branches = combineAdjacentBranches(branches).get)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe:

case e @ CaseWhen(branches, _) =>
  val prunedBranches = pruneSeenBranches(branches).getOrElse(branches)
  val newBranches = combineAdjacentBranches(prunedBranches).getOrElse(prunedBranches)
  if (newBranches.length < branches.length) {
    e.copy(branches = newBranches)
  } else {
    e
  }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If pruneSeenBranches and combineAdjacentBranches both return input branches if no changes applied, this maybe more simplified:

case e @ CaseWhen(branches, _) =>
  val prunedBranches = pruneSeenBranches(branches)
  val newBranches = combineAdjacentBranches(prunedBranches)
  if (newBranches.length < branches.length) {
    e.copy(branches = newBranches)
  } else {
    e
  }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will work. But my only concern is that this will be caught all the CaseWhen cases; as a result, no more rule can be added after this.

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val unreachableBranch = (FalseLiteral, Literal(20))
private val nullBranch = (Literal.create(null, NullType), Literal(30))

private val testRelation = LocalRelation('a.int)
val isNotNullCond = IsNotNull(UnresolvedAttribute("a"))
val isNullCond = IsNull(UnresolvedAttribute("a"))
val notCond = Not(UnresolvedAttribute("c"))

test("simplify if") {
assertEquivalent(
Expand Down Expand Up @@ -122,4 +124,54 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
None),
CaseWhen(normalBranch :: trueBranch :: Nil, None))
}

test("remove a branch in CaseWhen if a cond in this branch is previously seen") {
assertEquivalent(
CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) ::
(GreaterThan(Rand(0), Literal(0.5)), Literal(2)) ::
(NonFoldableLiteral(true), Literal(3)) ::
(LessThan(Rand(1), Literal(0.5)), Literal(4)) ::
(NonFoldableLiteral(true), Literal(5)) ::
(NonFoldableLiteral(false), Literal(6)) ::
(NonFoldableLiteral(false), Literal(7)) ::
Nil,
None),
CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) ::
(GreaterThan(Rand(0), Literal(0.5)), Literal(2)) ::
(NonFoldableLiteral(true), Literal(3)) ::
(LessThan(Rand(1), Literal(0.5)), Literal(4)) ::
(NonFoldableLiteral(false), Literal(6)) ::
Nil,
None)
)
}

test("combine two adjacent branches in CaseWhen if they have the same output values") {
assertEquivalent(
CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) ::
(NonFoldableLiteral(true), Literal(1)) ::
(LessThan(Rand(1), Literal(0.5)), Literal(3)) ::
(NonFoldableLiteral(true), Literal(3)) ::
(NonFoldableLiteral(false), Literal(4)) ::
Nil,
None),
CaseWhen((Or(GreaterThan(Rand(0), Literal(0.5)), NonFoldableLiteral(true)), Literal(1)) ::
(LessThan(Rand(1), Literal(0.5)), Literal(3)) ::
(NonFoldableLiteral(false), Literal(4)) ::
Nil,
None)
)

// The first two conditions can be combined, and then the optimizer uses rule in `Or`
// to be optimized into `TrueLiteral`. Thus, the entire `CaseWhen` can be removed.
assertEquivalent(
CaseWhen((UnresolvedAttribute("a"), Literal(1)) ::
(Not(UnresolvedAttribute("a")), Literal(1)) ::
(LessThan(Rand(1), Literal(0.5)), Literal(3)) ::
(NonFoldableLiteral(true), Literal(4)) ::
(NonFoldableLiteral(false), Literal(5)) ::
Nil,
None),
Literal(1))
}
}