1616// under the License.
1717
1818use arrow:: array:: builder:: { Int32Builder , StringBuilder } ;
19- use arrow:: datatypes:: { DataType , Field , Schema } ;
19+ use arrow:: array:: { Array , ArrayRef , Int32Array } ;
20+ use arrow:: datatypes:: { Field , Schema } ;
2021use arrow:: record_batch:: RecordBatch ;
2122use criterion:: { black_box, criterion_group, criterion_main, Criterion } ;
22- use datafusion_common:: ScalarValue ;
2323use datafusion_expr:: Operator ;
24- use datafusion_physical_expr:: expressions:: { BinaryExpr , CaseExpr , Column , Literal } ;
24+ use datafusion_physical_expr:: expressions:: { case , col , lit , BinaryExpr } ;
2525use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
2626use std:: sync:: Arc ;
2727
28- fn make_col ( name : & str , index : usize ) -> Arc < dyn PhysicalExpr > {
29- Arc :: new ( Column :: new ( name, index) )
28+ fn make_x_cmp_y (
29+ x : & Arc < dyn PhysicalExpr > ,
30+ op : Operator ,
31+ y : i32 ,
32+ ) -> Arc < dyn PhysicalExpr > {
33+ Arc :: new ( BinaryExpr :: new ( Arc :: clone ( x) , op, lit ( y) ) )
3034}
3135
32- fn make_lit_i32 ( n : i32 ) -> Arc < dyn PhysicalExpr > {
33- Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( n) ) ) )
34- }
35-
36- fn criterion_benchmark ( c : & mut Criterion ) {
37- // create input data
36+ fn make_batch ( row_count : usize , column_count : usize ) -> RecordBatch {
3837 let mut c1 = Int32Builder :: new ( ) ;
3938 let mut c2 = StringBuilder :: new ( ) ;
4039 let mut c3 = StringBuilder :: new ( ) ;
41- for i in 0 ..1000 {
42- c1. append_value ( i) ;
40+ for i in 0 ..row_count {
41+ c1. append_value ( i as i32 ) ;
4342 if i % 7 == 0 {
4443 c2. append_null ( ) ;
4544 } else {
@@ -54,69 +53,148 @@ fn criterion_benchmark(c: &mut Criterion) {
5453 let c1 = Arc :: new ( c1. finish ( ) ) ;
5554 let c2 = Arc :: new ( c2. finish ( ) ) ;
5655 let c3 = Arc :: new ( c3. finish ( ) ) ;
57- let schema = Schema :: new ( vec ! [
58- Field :: new( "c1" , DataType :: Int32 , true ) ,
59- Field :: new( "c2" , DataType :: Utf8 , true ) ,
60- Field :: new( "c3" , DataType :: Utf8 , true ) ,
61- ] ) ;
62- let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ c1, c2, c3] ) . unwrap ( ) ;
63-
64- // use same predicate for all benchmarks
65- let predicate = Arc :: new ( BinaryExpr :: new (
66- make_col ( "c1" , 0 ) ,
67- Operator :: LtEq ,
68- make_lit_i32 ( 500 ) ,
69- ) ) ;
56+ let mut columns: Vec < ArrayRef > = vec ! [ c1, c2, c3] ;
57+ for _ in 3 ..column_count {
58+ columns. push ( Arc :: new ( Int32Array :: from_value ( 0 , row_count) ) ) ;
59+ }
7060
71- // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END
72- c. bench_function ( "case_when: scalar or scalar" , |b| {
73- let expr = Arc :: new (
74- CaseExpr :: try_new (
75- None ,
76- vec ! [ ( predicate. clone( ) , make_lit_i32( 1 ) ) ] ,
77- Some ( make_lit_i32 ( 0 ) ) ,
61+ let fields = columns
62+ . iter ( )
63+ . enumerate ( )
64+ . map ( |( i, c) | {
65+ Field :: new (
66+ format ! ( "c{}" , i + 1 ) ,
67+ c. data_type ( ) . clone ( ) ,
68+ c. is_nullable ( ) ,
7869 )
79- . unwrap ( ) ,
80- ) ;
81- b. iter ( || black_box ( expr. evaluate ( black_box ( & batch) ) . unwrap ( ) ) )
82- } ) ;
70+ } )
71+ . collect :: < Vec < _ > > ( ) ;
8372
84- // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END
85- c. bench_function ( "case_when: column or null" , |b| {
86- let expr = Arc :: new (
87- CaseExpr :: try_new ( None , vec ! [ ( predicate. clone( ) , make_col( "c2" , 1 ) ) ] , None )
73+ let schema = Arc :: new ( Schema :: new ( fields) ) ;
74+ RecordBatch :: try_new ( Arc :: clone ( & schema) , columns) . unwrap ( )
75+ }
76+
77+ fn criterion_benchmark ( c : & mut Criterion ) {
78+ run_benchmarks ( c, & make_batch ( 8192 , 3 ) ) ;
79+ run_benchmarks ( c, & make_batch ( 8192 , 50 ) ) ;
80+ run_benchmarks ( c, & make_batch ( 8192 , 100 ) ) ;
81+ }
82+
83+ fn run_benchmarks ( c : & mut Criterion , batch : & RecordBatch ) {
84+ let c1 = col ( "c1" , & batch. schema ( ) ) . unwrap ( ) ;
85+ let c2 = col ( "c2" , & batch. schema ( ) ) . unwrap ( ) ;
86+ let c3 = col ( "c3" , & batch. schema ( ) ) . unwrap ( ) ;
87+
88+ // No expression, when/then/else, literal values
89+ c. bench_function (
90+ format ! (
91+ "case_when {}x{}: CASE WHEN c1 <= 500 THEN 1 ELSE 0 END" ,
92+ batch. num_rows( ) ,
93+ batch. num_columns( )
94+ )
95+ . as_str ( ) ,
96+ |b| {
97+ let expr = Arc :: new (
98+ case (
99+ None ,
100+ vec ! [ ( make_x_cmp_y( & c1, Operator :: LtEq , 500 ) , lit( 1 ) ) ] ,
101+ Some ( lit ( 0 ) ) ,
102+ )
88103 . unwrap ( ) ,
89- ) ;
90- b. iter ( || black_box ( expr. evaluate ( black_box ( & batch) ) . unwrap ( ) ) )
91- } ) ;
104+ ) ;
105+ b. iter ( || black_box ( expr. evaluate ( black_box ( batch) ) . unwrap ( ) ) )
106+ } ,
107+ ) ;
108+
109+ // No expression, when/then/else, column reference values
110+ c. bench_function (
111+ format ! (
112+ "case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 ELSE c3 END" ,
113+ batch. num_rows( ) ,
114+ batch. num_columns( )
115+ )
116+ . as_str ( ) ,
117+ |b| {
118+ let expr = Arc :: new (
119+ case (
120+ None ,
121+ vec ! [ ( make_x_cmp_y( & c1, Operator :: LtEq , 500 ) , Arc :: clone( & c2) ) ] ,
122+ Some ( Arc :: clone ( & c3) ) ,
123+ )
124+ . unwrap ( ) ,
125+ ) ;
126+ b. iter ( || black_box ( expr. evaluate ( black_box ( batch) ) . unwrap ( ) ) )
127+ } ,
128+ ) ;
92129
93- // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END
94- c. bench_function ( "case_when: expr or expr" , |b| {
130+ // No expression, when/then, implicit else
131+ c. bench_function (
132+ format ! (
133+ "case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END" ,
134+ batch. num_rows( ) ,
135+ batch. num_columns( )
136+ )
137+ . as_str ( ) ,
138+ |b| {
139+ let expr = Arc :: new (
140+ case (
141+ None ,
142+ vec ! [ ( make_x_cmp_y( & c1, Operator :: LtEq , 500 ) , Arc :: clone( & c2) ) ] ,
143+ None ,
144+ )
145+ . unwrap ( ) ,
146+ ) ;
147+ b. iter ( || black_box ( expr. evaluate ( black_box ( batch) ) . unwrap ( ) ) )
148+ } ,
149+ ) ;
150+
151+ // With expression, two when/then branches
152+ c. bench_function (
153+ format ! (
154+ "case_when {}x{}: CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END" ,
155+ batch. num_rows( ) ,
156+ batch. num_columns( )
157+ )
158+ . as_str ( ) ,
159+ |b| {
160+ let expr = Arc :: new (
161+ case (
162+ Some ( Arc :: clone ( & c1) ) ,
163+ vec ! [ ( lit( 1 ) , Arc :: clone( & c2) ) , ( lit( 2 ) , Arc :: clone( & c3) ) ] ,
164+ None ,
165+ )
166+ . unwrap ( ) ,
167+ ) ;
168+ b. iter ( || black_box ( expr. evaluate ( black_box ( batch) ) . unwrap ( ) ) )
169+ } ,
170+ ) ;
171+
172+ // Many when/then branches where all are effectively reachable
173+ c. bench_function ( format ! ( "case_when {}x{}: CASE WHEN c1 == 0 THEN 0 WHEN c1 == 1 THEN 1 ... WHEN c1 == n THEN n ELSE n + 1 END" , batch. num_rows( ) , batch. num_columns( ) ) . as_str ( ) , |b| {
174+ let when_thens = ( 0 ..batch. num_rows ( ) as i32 ) . map ( |i| ( make_x_cmp_y ( & c1, Operator :: Eq , i) , lit ( i) ) ) . collect ( ) ;
95175 let expr = Arc :: new (
96- CaseExpr :: try_new (
176+ case (
97177 None ,
98- vec ! [ ( predicate . clone ( ) , make_col ( "c2" , 1 ) ) ] ,
99- Some ( make_col ( "c3" , 2 ) ) ,
178+ when_thens ,
179+ Some ( lit ( batch . num_rows ( ) as i32 ) )
100180 )
101- . unwrap ( ) ,
181+ . unwrap ( ) ,
102182 ) ;
103- b. iter ( || black_box ( expr. evaluate ( black_box ( & batch) ) . unwrap ( ) ) )
183+ b. iter ( || black_box ( expr. evaluate ( black_box ( batch) ) . unwrap ( ) ) )
104184 } ) ;
105185
106- // CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END
107- c. bench_function ( "case_when: CASE expr" , |b| {
186+ // Many when/then branches where all but the first few are effectively unreachable
187+ c. bench_function ( format ! ( "case_when {}x{}: CASE WHEN c1 < 0 THEN 0 WHEN c1 < 1000 THEN 1 ... WHEN c1 < n * 1000 THEN n ELSE n + 1 END" , batch. num_rows( ) , batch. num_columns( ) ) . as_str ( ) , |b| {
188+ let when_thens = ( 0 ..batch. num_rows ( ) as i32 ) . map ( |i| ( make_x_cmp_y ( & c1, Operator :: Eq , i * 1000 ) , lit ( i) ) ) . collect ( ) ;
108189 let expr = Arc :: new (
109- CaseExpr :: try_new (
110- Some ( make_col ( "c1" , 0 ) ) ,
111- vec ! [
112- ( make_lit_i32( 1 ) , make_col( "c2" , 1 ) ) ,
113- ( make_lit_i32( 2 ) , make_col( "c3" , 2 ) ) ,
114- ] ,
190+ case (
115191 None ,
192+ when_thens,
193+ Some ( lit ( batch. num_rows ( ) as i32 ) )
116194 )
117- . unwrap ( ) ,
195+ . unwrap ( ) ,
118196 ) ;
119- b. iter ( || black_box ( expr. evaluate ( black_box ( & batch) ) . unwrap ( ) ) )
197+ b. iter ( || black_box ( expr. evaluate ( black_box ( batch) ) . unwrap ( ) ) )
120198 } ) ;
121199}
122200
0 commit comments