@viirya viirya commented May 28, 2021

What changes were proposed in this pull request?

This patch proposes to improve subexpression evaluation under whole-stage codegen for the cases of nested subexpressions.

Why are the changes needed?

With nested subexpressions, whole-stage codegen's subexpression elimination evaluates the inner subexpression redundantly. We should avoid that. For example, if we have two common subexpressions:

  1. simpleUDF($"id")
  2. functions.length(simpleUDF($"id"))

We should only evaluate simpleUDF($"id") once, i.e.

subExpr1 = simpleUDF($"id");
subExpr2 = functions.length(subExpr1);
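
To make the intended effect concrete, here is a minimal sketch in plain Java (not Spark code). The class name, the `simpleUdf` stand-in, and the call counter are all assumptions introduced for illustration; the sketch only shows that reusing the inner result halves the number of UDF invocations in this two-expression case.

```java
final class NestedSubExprDemo {
    // Counts invocations of the stand-in UDF (hypothetical, for illustration only).
    static int udfCalls = 0;

    // Stand-in for simpleUDF($"id"): converts the id to a string.
    static String simpleUdf(long id) {
        udfCalls++;
        return String.valueOf(id);
    }

    // Before: each common subexpression is evaluated from scratch.
    static int evalIndependently(long id) {
        String subExpr1 = simpleUdf(id);
        int subExpr2 = simpleUdf(id).length(); // runs the UDF a second time
        return subExpr1.length() + subExpr2;
    }

    // After: the outer subexpression reuses the inner result.
    static int evalNested(long id) {
        String subExpr1 = simpleUdf(id);
        int subExpr2 = subExpr1.length(); // no second UDF call
        return subExpr1.length() + subExpr2;
    }

    public static void main(String[] args) {
        udfCalls = 0;
        evalIndependently(12345L);
        System.out.println("independent evaluation: " + udfCalls + " UDF calls");

        udfCalls = 0;
        evalNested(12345L);
        System.out.println("nested evaluation: " + udfCalls + " UDF calls");
    }
}
```

Both variants compute the same result; only the number of UDF calls differs.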

Snippets of the generated code:

Before:

  private int project_subExpr_1(long project_expr_0_0) {
    boolean project_isNull_6 = false;
    UTF8String project_value_6 = null;
    if (!false) {
      project_value_6 = UTF8String.fromString(String.valueOf(project_expr_0_0));
    }

    Object project_arg_1 = null;
    if (project_isNull_6) {
      project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(null);
    } else {
      project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(project_value_6);
    }

    UTF8String project_result_1 = null;
    try {
      project_result_1 = (UTF8String)((scala.Function1[]) references[3] /* converters */)[1].apply(((scala.Function1) references[4] /* udf */).apply(project_arg_1));
    } catch (Throwable e) {
      throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
        "DataFrameSuite$$Lambda$6418/1507986601", "string", "string", e);
    }

    boolean project_isNull_5 = project_result_1 == null;
    UTF8String project_value_5 = null;
    if (!project_isNull_5) {
      project_value_5 = project_result_1;
    }
    boolean project_isNull_4 = project_isNull_5;
    int project_value_4 = -1;

    if (!project_isNull_5) {
      project_value_4 = (project_value_5).numChars();
    }
    project_subExprIsNull_1 = project_isNull_4;
    return project_value_4;
  }
  ...
  private UTF8String project_subExpr_0(long project_expr_0_0) {
    boolean project_isNull_2 = false;
    UTF8String project_value_2 = null;
    if (!false) {
      project_value_2 = UTF8String.fromString(String.valueOf(project_expr_0_0));
    }

    Object project_arg_0 = null;
    if (project_isNull_2) {
      project_arg_0 = ((scala.Function1[]) references[1] /* converters */)[0].apply(null);
    } else {
      project_arg_0 = ((scala.Function1[]) references[1] /* converters */)[0].apply(project_value_2);
    }

    UTF8String project_result_0 = null;
    try {
      project_result_0 = (UTF8String)((scala.Function1[]) references[1] /* converters */)[1].apply(((scala.Function1) references[2] /* udf */).apply(project_arg_0));
    } catch (Throwable e) {
      throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
        "DataFrameSuite$$Lambda$6418/1507986601", "string", "string", e);
    }

    boolean project_isNull_1 = project_result_0 == null;
    UTF8String project_value_1 = null;
    if (!project_isNull_1) {
      project_value_1 = project_result_0;
    }
    project_subExprIsNull_0 = project_isNull_1;
    return project_value_1;
  }

After:

  private void project_subExpr_1(long project_expr_0_0) {
    boolean project_isNull_8 = project_subExprIsNull_0;
    int project_value_8 = -1;

    if (!project_subExprIsNull_0) {
      project_value_8 = (project_mutableStateArray_0[0]).numChars();
    }
    project_subExprIsNull_1 = project_isNull_8;
    project_subExprValue_0 = project_value_8;
  }
  ...
  private void project_subExpr_0(long project_expr_0_0) {
    boolean project_isNull_6 = false;
    UTF8String project_value_6 = null;
    if (!false) {
      project_value_6 = UTF8String.fromString(String.valueOf(project_expr_0_0));
    }

    Object project_arg_1 = null;
    if (project_isNull_6) {
      project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(null);
    } else {
      project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(project_value_6);
    }

    UTF8String project_result_1 = null;
    try {
      project_result_1 = (UTF8String)((scala.Function1[]) references[3] /* converters */)[1].apply(((scala.Function1) references[4] /* udf */).apply(project_arg_1));
    } catch (Throwable e) {
      throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
        "DataFrameSuite$$Lambda$6430/2004847941", "string", "string", e);
    }

    boolean project_isNull_5 = project_result_1 == null;
    UTF8String project_value_5 = null;
    if (!project_isNull_5) {
      project_value_5 = project_result_1;
    }
    project_subExprIsNull_0 = project_isNull_5;
    project_mutableStateArray_0[0] = project_value_5;
  }

Does this PR introduce any user-facing change?

No

How was this patch tested?

Unit test.

@github-actions github-actions bot added the SQL label May 28, 2021
@viirya viirya force-pushed the improve-subexpr branch from 9c1b849 to 8c540cb Compare May 28, 2021 23:31

SparkQA commented May 29, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/43596/


SparkQA commented May 29, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/43596/

Kimahriman (Contributor) commented

Two questions just to see if I understand things correctly: this builds upon/relies upon the common expressions being sorted from your previous MR right? And theoretically if #32559 hadn't been fixed, this would sort of address that issue by creating a subexpression function for each level that builds on the previous instead of starting from scratch?

@viirya viirya force-pushed the improve-subexpr branch from 8c540cb to f1f64f7 Compare May 29, 2021 04:01
viirya commented May 29, 2021

> this builds upon/relies upon the common expressions being sorted from your previous MR right?

Yes, if you're asking if this is built on previous PR that sorts subexpressions.

> And theoretically if #32559 hadn't been fixed, this would sort of address that issue by creating a subexpression function for each level that builds on the previous instead of starting from scratch?

Yes. If #32559 were not fixed, this could relieve that issue somewhat, but it would still suffer from redundant subexpression calls in the split case.


SparkQA commented May 29, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/43597/


SparkQA commented May 29, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/43597/


SparkQA commented May 29, 2021

Test build #139076 has finished for PR 32699 at commit f1f64f7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

viirya commented May 29, 2021

cc @maropu @cloud-fan

Kimahriman (Contributor) commented

> Yes, if you're asking if this is built on previous PR that sorts subexpressions.

Hah whoops I spend most of my day dealing with merge requests not pull requests 😅

exprs.foreach(localSubExprEliminationExprs.put(_, state))
val nonSplitExprCode = {
  commonExprs.map { exprs =>
    val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
Contributor commented:

What does this recursive call of withSubExprEliminationExprs give us?

Member Author replied:

withSubExprEliminationExprs is called only once for each set of common expressions, so I don't think it is actually a recursive call.

withSubExprEliminationExprs takes a map used for subexpression elimination and, while generating code for the expressions in the closure, replaces common expressions with their entries from that map. It returns the evaluated expression code (value/isNull/code).

For the two subexpressions as example:

  1. simpleUDF($"id")
  2. functions.length(simpleUDF($"id"))

1st round of withSubExprEliminationExprs:

  • The map is empty.
  • Generate code for simpleUDF($"id").
  • Put it into the map => (simpleUDF($"id") -> generated code)

2nd round of withSubExprEliminationExprs:

  • Generate code for functions.length(simpleUDF($"id")).
  • Look up the map and replace the common expression simpleUDF($"id") with its generated code.
  • Put it into the map => (simpleUDF($"id") -> generated code, functions.length(simpleUDF($"id")) -> generated code)

The map is used later for subexpression elimination.
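
The two rounds described above can be sketched in plain Java. This is a toy model, not Spark's implementation: the `Expr` class, the `genCode` and `genCommonExprs` methods, and the `value_N` variable naming are hypothetical, and structural equality stands in for Spark's semantic equality. The sketch only illustrates how a map filled by earlier rounds lets later rounds reuse generated code.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

final class SubExprMapSketch {
    // Toy expression: an operator applied to an optional child (null child = input column).
    static final class Expr {
        final String op;
        final Expr child;

        Expr(String op, Expr child) {
            this.op = op;
            this.child = child;
        }

        // Structural equality stands in for semantic equality in this sketch.
        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Expr)) return false;
            Expr other = (Expr) o;
            return op.equals(other.op) && Objects.equals(child, other.child);
        }

        @Override
        public int hashCode() {
            return Objects.hash(op, child);
        }
    }

    // Maps an already generated common expression to the variable holding its value.
    private final Map<Expr, String> subExprs = new HashMap<>();
    private final List<String> code = new ArrayList<>();
    private int counter = 0;

    // Generates code for expr and returns the name of the variable holding its value.
    String genCode(Expr expr) {
        String known = subExprs.get(expr);
        if (known != null) {
            return known; // reuse earlier generated code instead of emitting it again
        }
        if (expr.child == null) {
            return expr.op; // an input column is referenced directly
        }
        String input = genCode(expr.child);
        String variable = "value_" + counter++;
        code.add(variable + " = " + expr.op + "(" + input + ");");
        return variable;
    }

    // One round per common expression; each round sees the entries of the earlier rounds.
    List<String> genCommonExprs(List<Expr> commonExprs) {
        for (Expr expr : commonExprs) {
            subExprs.put(expr, genCode(expr));
        }
        return code;
    }
}
```

Calling `genCommonExprs` with `simpleUDF(id)` followed by `length(simpleUDF(id))` emits one assignment per expression, and the second assignment refers to the first variable rather than repeating the UDF call. The reuse depends on inner expressions being processed before the expressions that contain them.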

Seq(expr.genCode(this))
}.head

val value = addMutableState(javaType(expr.dataType), "subExprValue")
Contributor commented:

Why does this have to be mutable state now?

Member Author replied:

Let me explain with the example from the description. For the two subexpressions:

  1. simpleUDF($"id")
  2. functions.length(simpleUDF($"id"))

Previously we evaluated them independently, i.e.,

String subExpr1 = simpleUDF($"id");
int subExpr2 = functions.length(simpleUDF($"id"));

Now we remove the redundant evaluation of nested subexpressions:

String subExpr1 = simpleUDF($"id");
int subExpr2 = functions.length(subExpr1);

If we need to split the subexpressions into separate functions, the function that evaluates functions.length needs access to subExpr1. We have two choices. One is to add subExpr1 to the parameter list of the split function for functions.length. The other is to use mutable state.

Adding it to the parameter list would complicate how we compute the parameter length. That is, we would need to track the relationships between nested subexpressions to derive the correct parameters. It does not seem worth doing to me.

For now I chose the simpler approach, which is to use mutable state.
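
The mutable-state approach can be sketched in plain Java as follows. The class, field, and method names are hypothetical and only loosely modeled on the generated code in the description; `String.valueOf` stands in for the UDF. The sketch shows the outer split function keeping the same `(long id)` signature while reading the inner result from fields.

```java
final class SplitSubExprSketch {
    // Mutable state shared between the split functions (hypothetical names).
    private String subExprValue0;
    private boolean subExprIsNull0;
    private int subExprValue1;
    private boolean subExprIsNull1;

    // Split function for the inner subexpression: stores its result in fields.
    private void subExpr0(long id) {
        String result = String.valueOf(id); // stand-in for simpleUDF($"id")
        subExprIsNull0 = result == null;
        subExprValue0 = result;
    }

    // Split function for the outer subexpression: no extra parameter is needed,
    // because the inner result is read from the fields written by subExpr0.
    private void subExpr1(long id) {
        subExprIsNull1 = subExprIsNull0;
        subExprValue1 = subExprIsNull0 ? -1 : subExprValue0.length();
    }

    // Evaluates both subexpressions; the inner one must run before the outer one.
    int lengthOfUdf(long id) {
        subExpr0(id);
        subExpr1(id);
        return subExprValue1;
    }
}
```

The trade-off in this sketch is an ordering dependency: `subExpr1` is only correct after `subExpr0` has run for the same input, whereas a parameter-passing design would make that dependency explicit in the signature.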


// Common exprs:
// 1. simpleUDF($"id")
// 2. functions.length(simpleUDF($"id"))
Member commented:

Q: What if a tree has more deeply nested common exprs? Does the current logic still work well? E.g., I imagine something like this:

        // subExpr1 = simpleUDF($"id");
        // subExpr2 = functions.length(subExpr1);
        // subExpr3 = functions.xxxx(subExpr2);
        // subExpr4 = ...

Member Author replied:

Yes, this is exactly the case this logic is meant to handle. The generated code of earlier common expressions is put into the map. When generating code for later common expressions, the code generator looks up the map and replaces any semantically equal expression with the value from that generated code.

Member replied:

Nice

maropu commented Jun 1, 2021

I left one question; the rest looks fine.

viirya commented Jun 1, 2021

Thanks @cloud-fan @maropu. I will merge this later if no more comments.

viirya commented Jun 1, 2021

Thank you @dongjoon-hyun

viirya commented Jun 2, 2021

Thanks all! Merging to master.

@viirya viirya closed this in dbf0b50 Jun 2, 2021
@viirya viirya deleted the improve-subexpr branch June 2, 2021 02:14
Kimahriman pushed a commit to Kimahriman/spark that referenced this pull request Feb 22, 2022
…d subexpressions

Closes apache#32699 from viirya/improve-subexpr.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>