@@ -83,20 +83,35 @@ class EquivalentExpressions {
    * Adds only expressions which are common in each of given expressions, in a recursive way.
    * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
    * the common expression `(c + 1)` will be added into `equivalenceMap`.
+   *
+   * Note that as we don't know in advance if any child node of an expression will be common
+   * across all given expressions, we count all child nodes when looking through the given
+   * expressions. But when we call `addExprTree` to add common expressions into the map, we
+   * will add recursively the child nodes. So we need to filter the child expressions first.
+   * For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add
+   * `((a + b) + c)`.
    */
   private def addCommonExprs(
       exprs: Seq[Expression],
       addFunc: Expression => Boolean = addExpr): Unit = {
     val exprSetForAll = mutable.Set[Expr]()
Contributor @Kimahriman (May 20, 2021):

One possibly unrelated thing I just noticed: do we need to keep track of all of the Expressions here as well (as in an Expr -> Seq[Expression] map)? This basically keeps only the first Expression found, but the codegen looks like it uses the Expression hash (rather than the semantic hash) to look up subexpressions. Very much an edge case; I'm just wondering if I'm understanding things correctly.

Member Author:

You mean equivalenceMap?

Contributor:

I don't mean adding it directly to that here. I'm just thinking of a really stupid example: when((col + 1) > 0, col + 1).otherwise(1 + col). Wouldn't col + 1 and 1 + col resolve as a common expression because they're semantically equal, while only col + 1 is added to equivalenceMap, so that during codegen 1 + col wouldn't be resolved to the subexpression?

Member Author:

col + 1 and 1 + col will both be recognized as subexpressions.

Contributor:

Yeah, but won't the codegen stage fail to replace 1 + col, since only col + 1 will be added to the equivalenceMap entry for Expr(col + 1)? For non-commonExprs cases, both would be in equivalenceMap, so the codegen stage maps both of those expressions to the resulting subexpression. Again, not super related to this PR, but this was the easiest place to ask.

Member Author:

Both 1 + col and col + 1 will be replaced with the extracted subexpression during codegen. We don't just look up the key in equivalenceMap when replacing with the subexpression.
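
To make the distinction in this thread concrete, here is a minimal sketch of structural versus semantic equality on a toy expression tree. Everything in it (`ToyExpr`, `canonical`, `semanticEquals`, the string-based operand ordering, the `subExpr_0` name) is a hypothetical stand-in, not Spark's `Expression` API and not a description of how Spark's codegen performs its lookup. It only illustrates why a lookup keyed on a canonical form finds both `col + 1` and `1 + col`, while a lookup keyed on the raw tree finds just one of them.

```scala
object SemanticEqualitySketch {
  // Toy expression tree (hypothetical; not Spark's Expression hierarchy).
  sealed trait ToyExpr
  final case class Col(name: String) extends ToyExpr
  final case class Lit(value: Int) extends ToyExpr
  final case class Plus(left: ToyExpr, right: ToyExpr) extends ToyExpr

  // Canonical form: order the operands of the commutative Plus by their
  // string rendering, so operand order no longer matters.
  def canonical(e: ToyExpr): ToyExpr = e match {
    case Plus(l, r) =>
      val (cl, cr) = (canonical(l), canonical(r))
      if (cl.toString <= cr.toString) Plus(cl, cr) else Plus(cr, cl)
    case leaf => leaf
  }

  // Two trees are "semantically equal" here when their canonical forms match.
  def semanticEquals(a: ToyExpr, b: ToyExpr): Boolean =
    canonical(a) == canonical(b)

  val colPlusOne: ToyExpr = Plus(Col("col"), Lit(1))
  val onePlusCol: ToyExpr = Plus(Lit(1), Col("col"))

  // Keyed on the raw tree: only the spelling that was inserted is found.
  val byStructure: Map[ToyExpr, String] = Map(colPlusOne -> "subExpr_0")
  // Keyed on the canonical form: both spellings resolve to the same entry.
  val byCanonical: Map[ToyExpr, String] = Map(canonical(colPlusOne) -> "subExpr_0")
}
```

In this toy model, `byStructure.get(onePlusCol)` is empty while `byCanonical.get(canonical(onePlusCol))` returns the shared entry, which is the behavior the question above is probing.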

     addExprTree(exprs.head, addExprToSet(_, exprSetForAll))

-    val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
+    val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
       val otherExprSet = mutable.Set[Expr]()
       addExprTree(expr, addExprToSet(_, otherExprSet))
       exprSet.intersect(otherExprSet)
     }

-    commonExprSet.foreach(expr => addFunc(expr.e))
+    // Not all expressions in the set should be added. We should filter out the related
+    // children nodes.
+    val commonExprSet = candidateExprs.filter { candidateExpr =>
+      candidateExprs.forall { expr =>
+        expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
+      }
Member:

Isn't this loop expensive? It seems the time complexity is O((the total number of expr nodes in candidateExprs) x (candidateExprs.size)^2)?

Contributor:

+1, but I don't have a better idea now...

Member Author:

Yea, I considered this part but couldn't come up with a better one.

Member @maropu (May 17, 2021):

Yea, okay. I don't have an idea, either... That was just a question.

+    }
+
+    commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
   }

   // There are some special expressions that we should not recurse into all of its children.
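
The child-filtering step added to addCommonExprs in this hunk can be illustrated on a toy tree. This is a minimal sketch under stated assumptions: `ToyExpr`, `Leaf`, `Plus`, and `dropRedundantChildren` are hypothetical names, plain `==` stands in for `semanticEquals`, and an immutable `Set` stands in for the mutable set of `Expr` wrappers. It keeps a candidate only when no other candidate contains it as a sub-tree, mirroring the shape of the `filter`/`forall`/`find` expression in the diff.

```scala
object ChildFilterSketch {
  // Toy expression tree (hypothetical; not Spark's Expression hierarchy).
  sealed trait ToyExpr {
    def children: Seq[ToyExpr]

    // Pre-order search over this node and all of its descendants.
    def find(p: ToyExpr => Boolean): Option[ToyExpr] =
      if (p(this)) Some(this)
      else children.iterator.map(_.find(p)).collectFirst { case Some(hit) => hit }
  }
  final case class Leaf(name: String) extends ToyExpr {
    val children: Seq[ToyExpr] = Nil
  }
  final case class Plus(left: ToyExpr, right: ToyExpr) extends ToyExpr {
    val children: Seq[ToyExpr] = Seq(left, right)
  }

  // Keep a candidate only if no *other* candidate contains it as a sub-tree.
  def dropRedundantChildren(candidates: Set[ToyExpr]): Set[ToyExpr] =
    candidates.filter { candidate =>
      candidates.forall { other =>
        other == candidate || other.find(_ == candidate).isEmpty
      }
    }
}
```

With candidates `(a + b)` and `((a + b) + c)`, only `((a + b) + c)` survives, matching the example in the doc comment. In this toy version every candidate triggers a traversal of every other candidate, so the number of node comparisons grows with the number of candidates times the total number of nodes across them, which is the cost raised in the review thread above.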
@@ -310,6 +310,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     }
   }

+  test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
+    "for conditional expressions") {
Contributor:

Is this only a problem for conditional expressions?

Member Author:

So far it's the only one I can think of.

Contributor:

Found a non-conditional example that is still an issue even with this update (a bit contrived, but I'm sure there's a real use case):

val myUdf = udf(() => {
  println("In UDF")
  1
}).withName("myUdf")
spark.range(1).withColumn("a", myUdf()).select(($"a" + $"a") / ($"a" + $"a")).show()

This generates subexpressions myUdf() and (myUdf() + myUdf()), even though only the second one is used.
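
Why both the child and the parent show up as candidates can be seen with a naive occurrence count on a toy tree. This is only a sketch under loud assumptions: `ToyExpr`, `UdfCall`, `Plus`, `Div`, and `repeatedSubtrees` are hypothetical names, and the count visits every node, which is a simplification and not Spark's actual `addExprTree` traversal, so the exact counts need not match Spark's.

```scala
object RedundantChildSketch {
  // Toy expression tree (hypothetical; not Spark's Expression hierarchy).
  sealed trait ToyExpr { def children: Seq[ToyExpr] }
  final case class UdfCall(name: String) extends ToyExpr {
    val children: Seq[ToyExpr] = Nil
  }
  final case class Plus(left: ToyExpr, right: ToyExpr) extends ToyExpr {
    val children: Seq[ToyExpr] = Seq(left, right)
  }
  final case class Div(left: ToyExpr, right: ToyExpr) extends ToyExpr {
    val children: Seq[ToyExpr] = Seq(left, right)
  }

  // Every node of the tree in pre-order, duplicates included.
  def allNodes(e: ToyExpr): Seq[ToyExpr] = e +: e.children.flatMap(allNodes)

  // Sub-trees that occur more than once, with their occurrence counts.
  def repeatedSubtrees(e: ToyExpr): Map[ToyExpr, Int] =
    allNodes(e).groupBy(identity).map { case (k, v) => k -> v.size }.filter(_._2 > 1)

  val a: ToyExpr = UdfCall("myUdf")
  val sum: ToyExpr = Plus(a, a)
  val expr: ToyExpr = Div(sum, sum) // models (a + a) / (a + a)
}
```

In this toy count both `myUdf()` and `(myUdf() + myUdf())` are repeated, yet once the parent is extracted the child is never evaluated on its own, so keeping the child as a separate subexpression is redundant.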

Member Author:

Thanks @Kimahriman. I see. Let me also look at it. It is a non-conditional case, but it looks like a similar one. Let me see if it can be solved similarly.

Member Author:

Oh, I figured it out. This might be an issue since we have sub-expr elimination. We also need to remove redundant child exprs for non-conditional cases.

Member Author:

But the fix might be different. I will work on it locally and submit another fix for it.

Contributor:

Any more thoughts on this? Was the subexpr sorting supposed to address this?

Member Author:

It might need another fix. I'm working on it and will submit it after these PRs are merged.

+    val add1 = Add(Literal(1), Literal(2))
+    val add2 = Add(Literal(2), Literal(3))
+    val add3 = Add(add1, add2)
+    val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil
+
+    val caseWhenExpr = CaseWhen(condition, None)
+    val equivalence = new EquivalentExpressions
+    equivalence.addExprTree(caseWhenExpr)
+
+    val commonExprs = equivalence.getAllEquivalentExprs(1)
+    assert(commonExprs.size == 1)
+    assert(commonExprs.head === Seq(add3, add3))
+  }
+
   test("SPARK-35439: Children subexpr should come first than parent subexpr") {
     val add = Add(Literal(1), Literal(2))

19 changes: 19 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2861,6 +2861,25 @@ class DataFrameSuite extends QueryTest
       checkAnswer(result, Row(0, 0, 0, 0, 100))
     }
   }
+
+  test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
+    "for conditional expressions") {
+    val accum = sparkContext.longAccumulator("call")
+    val simpleUDF = udf((s: String) => {
+      accum.add(1)
+      s
+    })
+    val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0,
Contributor:

I think the fix for https://issues.apache.org/jira/browse/SPARK-35449 will break this, since it's really a "bug" that the case value is included in subexpression resolution without an else value. Not a huge deal, I can try to fix in my follow up once this is merged

+      functions.length(simpleUDF($"id"))))
+    df1.collect()
+    assert(accum.value == 5)
+
+    val nondeterministicUDF = simpleUDF.asNondeterministic()
+    val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0,
+      functions.length(nondeterministicUDF($"id"))))
+    df2.collect()
+    assert(accum.value == 15)
+  }
 }

 case class GroupByKey(a: Int, b: Int)
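
The two accumulator values asserted in the SPARK-35410 DataFrameSuite test above follow from simple call counting, sketched here without Spark. The names `UdfCallCountSketch`, `countCalls`, and `shareSubexpr` are hypothetical, and the model assumes what the test implies: with the deterministic UDF the shared `length(udf(id))` is evaluated once per row, and with the nondeterministic UDF it is evaluated separately in the condition and in the value.

```scala
object UdfCallCountSketch {
  // Models `when(length(f(id)) > 0, length(f(id)))` over `rows` rows and
  // returns how many times f was invoked in total.
  def countCalls(rows: Int, shareSubexpr: Boolean): Int = {
    var calls = 0
    def f(id: Long): String = { calls += 1; id.toString }

    val results = (0 until rows).map { i =>
      val id = i.toLong
      if (shareSubexpr) {
        val len = f(id).length // evaluated once, reused by condition and value
        if (len > 0) Some(len) else None
      } else {
        // condition and value each evaluate f on their own
        if (f(id).length > 0) Some(f(id).length) else None
      }
    }
    require(results.size == rows)
    calls
  }
}
```

For 5 rows this gives 5 calls when the subexpression is shared and 10 when it is not. Because the test reuses one accumulator across both queries, the second assertion checks the running total, 5 + 10 = 15.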