-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35560][SQL] Remove redundant subexpression evaluation in nested subexpressions #32699
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Kubernetes integration test starting |
|
Kubernetes integration test status success |
This comment has been minimized.
This comment has been minimized.
|
Two questions just to see if I understand things correctly: this builds upon/relies upon the common expressions being sorted from your previous MR right? And theoretically if #32559 hadn't been fixed, this would sort of address that issue by creating a subexpression function for each level that builds on the previous instead of starting from scratch? |
Yes, if you're asking if this is built on previous PR that sorts subexpressions.
Yea, if #32559 is not fixed, this could somehow relieve that but it still suffers from redundant subexpression call if it is split case. |
|
Kubernetes integration test starting |
|
Kubernetes integration test status success |
|
Test build #139076 has finished for PR 32699 at commit
|
Hah whoops I spend most of my day dealing with merge requests not pull requests 😅 |
| exprs.foreach(localSubExprEliminationExprs.put(_, state)) | ||
| val nonSplitExprCode = { | ||
| commonExprs.map { exprs => | ||
| val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does this recursive call of withSubExprEliminationExprs give us?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For each set of common expressions, withSubExprEliminationExprs only called once so I think it is not actually a recursive call?
withSubExprEliminationExprs takes the given map used for subexpression elimination to replace common expression during expression codegen in the closure. It returns evaluated expression code (value/isNull/code).
For the two subexpressions as example:
simpleUDF($"id")functions.length(simpleUDF($"id"))
1st round withSubExprEliminationExprs:
The map is empty.
Gen code for simpleUDF($"id").
Put it into the map => (simpleUDF($"id") -> gen-ed code)
2nd round withSubExprEliminationExprs:
Gen code for functions.length(simpleUDF($"id")).
Looking at the map and replace common expression simpleUDF($"id") as gen-ed code.
Put it into the map => (simpleUDF($"id") -> gen-ed code, functions.length(simpleUDF($"id")) -> gen-ed code)
The map will be used later for subexpression elimination.
| Seq(expr.genCode(this)) | ||
| }.head | ||
|
|
||
| val value = addMutableState(javaType(expr.dataType), "subExprValue") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why does this have to be a mutable state now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the example in the description to explain. For the two subexpressions:
simpleUDF($"id")functions.length(simpleUDF($"id"))
Previously we evaluate them independently, i.e.,
String subExpr1 = simpleUDF($"id");
Int subExpr2 = functions.length(simpleUDF($"id"));
Now we remove redundant evaluation of nested subexpressions:
String subExpr1 = simpleUDF($"id");
Int subExpr2 = functions.length(subExpr1);
If we need to split the functions, when we evaluate functions.length, it needs access of subExpr1. We have two choices. One is to add subExpr1 to the function parameter list of the split function for functions.length. Another one is to use mutable state.
To add it to parameter list will complicate the way we compute parameter length. That's said we need to link nested subexpression relations and get the correct parameters. Seems to me it is not worth doing that.
Currently I choose the simpler approach that is to use mutable state.
|
|
||
| // Common exprs: | ||
| // 1. simpleUDF($"id") | ||
| // 2. functions.length(simpleUDF($"id")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Q: What if a tree has more deeply-nested common exprs? The current logic can work well? e.g., I thought it like this;
// subExpr1 = simpleUDF($"id");
// subExpr2 = functions.length(subExpr1);
// subExpr3 = functions.xxxx(subExpr2);
// subExpr4 = ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is actually the cases this logic to deal with. Previous common expression gen-ed codes will put into the map. The code generator looks up into the map when generating code for later common expressions to replace the semantic-equal expression with gen-ed code value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice
|
I left one question and the other part looks fine. |
|
Thanks @cloud-fan @maropu. I will merge this later if no more comments. |
|
Thank you @dongjoon-hyun |
|
Thanks all! Merging to master. |
…d subexpressions
This patch proposes to improve subexpression evaluation under whole-stage codegen for the cases of nested subexpressions.
In the cases of nested subexpressions, whole-stage codegen's subexpression elimination will do redundant subexpression evaluation. We should reduce it. For example, if we have two sub-exprs:
1. `simpleUDF($"id")`
2. `functions.length(simpleUDF($"id"))`
We should only evaluate `simpleUDF($"id")` once, i.e.
```java
subExpr1 = simpleUDF($"id");
subExpr2 = functions.length(subExpr1);
```
Snippets of generated codes:
Before:
```java
/* 040 */ private int project_subExpr_1(long project_expr_0_0) {
/* 041 */ boolean project_isNull_6 = false;
/* 042 */ UTF8String project_value_6 = null;
/* 043 */ if (!false) {
/* 044 */ project_value_6 = UTF8String.fromString(String.valueOf(project_expr_0_0));
/* 045 */ }
/* 046 */
/* 047 */ Object project_arg_1 = null;
/* 048 */ if (project_isNull_6) {
/* 049 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(null);
/* 050 */ } else {
/* 051 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(project_value_6); /* 052 */ }
/* 053 */
/* 054 */ UTF8String project_result_1 = null; /* 055 */ try { /* 056 */ project_result_1 = (UTF8String)((scala.Function1[]) references[3] /* converters */)[1].apply(((scala.Function1) references[4] /* udf */).apply(project_arg_1)
);
/* 057 */ } catch (Throwable e) {
/* 058 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
/* 059 */ "DataFrameSuite$$Lambda$6418/1507986601", "string", "string", e);
/* 060 */ }
/* 061 */
/* 062 */ boolean project_isNull_5 = project_result_1 == null;
/* 063 */ UTF8String project_value_5 = null;
/* 064 */ if (!project_isNull_5) {
/* 065 */ project_value_5 = project_result_1;
/* 066 */ }
/* 067 */ boolean project_isNull_4 = project_isNull_5;
/* 068 */ int project_value_4 = -1;
/* 069 */
/* 070 */ if (!project_isNull_5) {
/* 071 */ project_value_4 = (project_value_5).numChars();
/* 072 */ }
/* 073 */ project_subExprIsNull_1 = project_isNull_4;
/* 074 */ return project_value_4;
/* 075 */ }
...
/* 149 */ private UTF8String project_subExpr_0(long project_expr_0_0) {
/* 150 */ boolean project_isNull_2 = false;
/* 151 */ UTF8String project_value_2 = null;
/* 152 */ if (!false) {
/* 153 */ project_value_2 = UTF8String.fromString(String.valueOf(project_expr_0_0));
/* 154 */ }
/* 155 */
/* 156 */ Object project_arg_0 = null;
/* 157 */ if (project_isNull_2) {
/* 158 */ project_arg_0 = ((scala.Function1[]) references[1] /* converters */)[0].apply(null);
/* 159 */ } else {
/* 160 */ project_arg_0 = ((scala.Function1[]) references[1] /* converters */)[0].apply(project_value_2);
/* 161 */ }
/* 162 */
/* 163 */ UTF8String project_result_0 = null;
/* 164 */ try {
/* 165 */ project_result_0 = (UTF8String)((scala.Function1[]) references[1] /* converters */)[1].apply(((scala.Function1) references[2] /* udf */).apply(project_arg_0)
);
/* 166 */ } catch (Throwable e) {
/* 167 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
/* 168 */ "DataFrameSuite$$Lambda$6418/1507986601", "string", "string", e);
/* 169 */ }
/* 170 */
/* 171 */ boolean project_isNull_1 = project_result_0 == null; /* 172 */ UTF8String project_value_1 = null; /* 173 */ if (!project_isNull_1) { /* 174 */ project_value_1 = project_result_0;
/* 175 */ }
/* 176 */ project_subExprIsNull_0 = project_isNull_1;
/* 177 */ return project_value_1;
/* 178 */ }
```
After:
```java
/* 041 */ private void project_subExpr_1(long project_expr_0_0) {
/* 042 */ boolean project_isNull_8 = project_subExprIsNull_0;
/* 043 */ int project_value_8 = -1;
/* 044 */
/* 045 */ if (!project_subExprIsNull_0) {
/* 046 */ project_value_8 = (project_mutableStateArray_0[0]).numChars();
/* 047 */ }
/* 048 */ project_subExprIsNull_1 = project_isNull_8;
/* 049 */ project_subExprValue_0 = project_value_8;
/* 050 */ }
/* 056 */
...
/* 123 */
/* 124 */ private void project_subExpr_0(long project_expr_0_0) {
/* 125 */ boolean project_isNull_6 = false;
/* 126 */ UTF8String project_value_6 = null;
/* 127 */ if (!false) {
/* 128 */ project_value_6 = UTF8String.fromString(String.valueOf(project_expr_0_0));
/* 129 */ }
/* 130 */
/* 131 */ Object project_arg_1 = null;
/* 132 */ if (project_isNull_6) {
/* 133 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(null);
/* 134 */ } else {
/* 135 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(project_value_6);
/* 136 */ }
/* 137 */
/* 138 */ UTF8String project_result_1 = null;
/* 139 */ try {
/* 140 */ project_result_1 = (UTF8String)((scala.Function1[]) references[3] /* converters */)[1].apply(((scala.Function1) references[4] /* udf */).apply(project_arg_1)
);
/* 141 */ } catch (Throwable e) {
/* 142 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
/* 143 */ "DataFrameSuite$$Lambda$6430/2004847941", "string", "string", e);
/* 144 */ }
/* 145 */
/* 146 */ boolean project_isNull_5 = project_result_1 == null;
/* 147 */ UTF8String project_value_5 = null;
/* 148 */ if (!project_isNull_5) {
/* 149 */ project_value_5 = project_result_1;
/* 150 */ }
/* 151 */ project_subExprIsNull_0 = project_isNull_5;
/* 152 */ project_mutableStateArray_0[0] = project_value_5;
/* 153 */ }
```
No
Unit test.
Closes apache#32699 from viirya/improve-subexpr.
Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
What changes were proposed in this pull request?
This patch proposes to improve subexpression evaluation under whole-stage codegen for the cases of nested subexpressions.
Why are the changes needed?
In the cases of nested subexpressions, whole-stage codegen's subexpression elimination will do redundant subexpression evaluation. We should reduce it. For example, if we have two sub-exprs:
simpleUDF($"id")functions.length(simpleUDF($"id"))We should only evaluate
simpleUDF($"id")once, i.e.Snippets of generated codes:
Before:
After:
Does this PR introduce any user-facing change?
No
How was this patch tested?
Unit test.