diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs
index eb0886a31e8df..b043b6468472b 100644
--- a/datafusion/physical-expr/benches/case_when.rs
+++ b/datafusion/physical-expr/benches/case_when.rs
@@ -93,6 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     run_benchmarks(c, &make_batch(8192, 100));
 
     benchmark_lookup_table_case_when(c, 8192);
+    benchmark_searched_case_when_many_branches(c, 8192);
 }
 
 fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) {
@@ -517,5 +518,89 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) {
     }
 }
 
+/// Benchmark for searched CASE WHEN with many branches (vectorized optimization)
+///
+/// For each of 5, 10 and 20 branches this registers two benchmarks over `batch_size`
+/// rows: one whose THEN values are string literals and one whose THEN values are a
+/// column.
+fn benchmark_searched_case_when_many_branches(c: &mut Criterion, batch_size: usize) {
+    let mut group = c.benchmark_group("searched_case_when_many_branches");
+
+    // The input batches do not depend on the number of branches, so build them once.
+    // Every column contains the values 0..batch_size.
+    let array = Arc::new(Int32Array::from_iter_values(0..batch_size as i32));
+    let batch = RecordBatch::try_new(
+        Arc::new(Schema::new(vec![Field::new(
+            "c1",
+            array.data_type().clone(),
+            false,
+        )])),
+        vec![Arc::clone(&array) as ArrayRef],
+    )
+    .unwrap();
+    let batch_with_c2 = RecordBatch::try_new(
+        Arc::new(Schema::new(vec![
+            Field::new("c1", array.data_type().clone(), false),
+            Field::new("c2", array.data_type().clone(), false),
+        ])),
+        vec![Arc::clone(&array) as ArrayRef, array],
+    )
+    .unwrap();
+
+    let c1 = col("c1", &batch.schema()).unwrap();
+    let c1_of_c2_batch = col("c1", &batch_with_c2.schema()).unwrap();
+    let c2_col = col("c2", &batch_with_c2.schema()).unwrap();
+
+    for num_branches in [5, 10, 20] {
+        // CASE WHEN c1 < threshold1 THEN 'branch_0'
+        //      WHEN c1 < threshold2 THEN 'branch_1' ...
+        //      ELSE 'else_branch' END
+        // Thresholds are evenly distributed so all branches get some rows
+        let step = batch_size as i32 / num_branches as i32;
+        let when_thens: Vec<_> = (0..num_branches)
+            .map(|i| {
+                let threshold = (i as i32 + 1) * step;
+                (
+                    make_x_cmp_y(&c1, Operator::Lt, threshold),
+                    lit(format!("branch_{}", i)),
+                )
+            })
+            .collect();
+
+        let expr = Arc::new(case(None, when_thens, Some(lit("else_branch"))).unwrap());
+
+        group.bench_function(
+            format!(
+                "{} branches, {} rows, conditions with column refs",
+                num_branches, batch_size
+            ),
+            |b| b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())),
+        );
+
+        // Also benchmark with column THEN values (not just literals)
+        let when_thens: Vec<_> = (0..num_branches)
+            .map(|i| {
+                let threshold = (i as i32 + 1) * step;
+                (
+                    make_x_cmp_y(&c1_of_c2_batch, Operator::Lt, threshold),
+                    Arc::clone(&c2_col),
+                )
+            })
+            .collect();
+
+        let expr = Arc::new(case(None, when_thens, Some(lit(0))).unwrap());
+
+        group.bench_function(
+            format!(
+                "{} branches, {} rows, column THEN values",
+                num_branches, batch_size
+            ),
+            |b| b.iter(|| black_box(expr.evaluate(black_box(&batch_with_c2)).unwrap())),
+        );
+    }
+
+    group.finish();
+}
+
 criterion_group!(benches, criterion_benchmark);
 criterion_main!(benches);
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs
index 758317d3d2798..a791a5aa7f5da 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -81,6 +81,17 @@ enum EvalMethod {
     ///
     /// See [`LiteralLookupTable`] for more details
     WithExprScalarLookupTable(LiteralLookupTable),
+
+    /// Vectorized evaluation for CASE WHEN condition expressions with multiple branches.
+    ///
+    /// This optimization evaluates all conditions upfront on the full batch,
+    /// builds a branch index array, then filters once per branch.
+    ///
+    /// **Important**: This changes evaluation semantics - conditions are evaluated
+    /// even for rows where an earlier condition matched. This is safe when conditions
+    /// are simple comparisons but may cause issues if conditions can error
+    /// (e.g., division by zero in a condition).
+    NoExpressionVectorized(ProjectedCaseBody),
 }
 
 /// Implementing hash so we can use `derive` on [`EvalMethod`].
@@ -661,6 +672,11 @@ impl CaseExpr {
                 EvalMethod::ScalarOrScalar
             } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
                 EvalMethod::ExpressionOrExpression(body.project()?)
+            } else if body.when_then_expr.len() >= 3 {
+                // Use vectorized evaluation for 3+ branches
+                // This evaluates all conditions upfront for better cache locality
+                // Note: This changes short-circuit semantics for CONDITIONS
+                EvalMethod::NoExpressionVectorized(body.project()?)
             } else {
                 EvalMethod::NoExpression(body.project()?)
             },
@@ -940,6 +956,152 @@ impl CaseBody {
         result_builder.finish()
     }
 
+    /// Vectorized evaluation for CASE WHEN expressions with multiple branches.
+    ///
+    /// Instead of evaluating conditions sequentially on progressively shrinking batches,
+    /// this evaluates all conditions upfront on the full batch, builds a branch index
+    /// array indicating which branch matched each row, then filters once per branch.
+    /// A condition that evaluates to NULL for a row is treated as not matching it.
+    ///
+    /// This is intended to perform better for many branches due to:
+    /// - Better cache locality (conditions evaluated on same data)
+    /// - Simpler filter predicates (integer equality vs boolean expressions)
+    /// - No progressive batch shrinking overhead
+    ///
+    /// **Important**: This changes short-circuit semantics for CONDITIONS (not THEN expressions).
+    /// All conditions are evaluated even for rows where an earlier condition matched.
+    /// This is safe for simple comparisons but may cause issues if conditions can error.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if a WHEN, THEN or ELSE expression fails to evaluate, or if a
+    /// WHEN expression does not evaluate to a boolean array.
+    fn case_when_no_expr_vectorized(
+        &self,
+        batch: &RecordBatch,
+        return_type: &DataType,
+    ) -> Result<ColumnarValue> {
+        let num_rows = batch.num_rows();
+
+        // With zero rows every per-branch count trivially equals `num_rows`, so the
+        // single-branch shortcut below would evaluate and return the first THEN
+        // expression although no row selected it. There is nothing to evaluate.
+        if num_rows == 0 {
+            return Ok(ColumnarValue::Array(new_null_array(return_type, 0)));
+        }
+
+        let num_branches = self.when_then_expr.len();
+
+        // else_index is used for rows that don't match any condition
+        let else_index = num_branches as u32;
+
+        // branch_indices[row] = index of first matching branch, or else_index if none match
+        let mut branch_indices = vec![else_index; num_rows];
+
+        // Evaluate all conditions and build branch index array
+        // We iterate in reverse so earlier conditions overwrite later ones
+        // (implementing first-match semantics)
+        for (branch_idx, (when_predicate, _)) in
+            self.when_then_expr.iter().enumerate().rev()
+        {
+            let when_value = when_predicate.evaluate(batch)?.into_array(num_rows)?;
+            let when_value = as_boolean_array(&when_value).map_err(|_| {
+                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
+            })?;
+
+            // For each true value in the condition, set branch_indices to this branch
+            let branch_idx_u32 = branch_idx as u32;
+            let truth_bits = when_value.values().iter();
+            match when_value.nulls() {
+                Some(nulls) => {
+                    // Handle nulls - treat null as false
+                    let matches = truth_bits.zip(nulls.iter());
+                    for (slot, (is_true, is_valid)) in
+                        branch_indices.iter_mut().zip(matches)
+                    {
+                        if is_valid && is_true {
+                            *slot = branch_idx_u32;
+                        }
+                    }
+                }
+                None => {
+                    // No nulls - just check boolean values
+                    for (slot, is_true) in branch_indices.iter_mut().zip(truth_bits) {
+                        if is_true {
+                            *slot = branch_idx_u32;
+                        }
+                    }
+                }
+            }
+        }
+
+        // Count rows per branch to identify which branches have matches
+        let mut branch_counts = vec![0usize; num_branches + 1]; // +1 for else
+        for &branch_idx in &branch_indices {
+            branch_counts[branch_idx as usize] += 1;
+        }
+
+        // If all rows went to a single slot, evaluate that expression on the whole batch.
+        // `num_rows > 0` here, so at most one count can equal `num_rows`.
+        if let Some(only) = branch_counts.iter().position(|&count| count == num_rows) {
+            return match self.when_then_expr.get(only) {
+                // All rows matched this single WHEN branch
+                Some((_, then_expr)) => then_expr.evaluate(batch),
+                // `only == num_branches`: all rows go to ELSE
+                None => match &self.else_expr {
+                    Some(e) => {
+                        try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?
+                            .evaluate(batch)
+                    }
+                    // No else expr, return nulls
+                    None => Ok(ColumnarValue::Array(new_null_array(
+                        return_type,
+                        num_rows,
+                    ))),
+                },
+            };
+        }
+
+        // Collects, for one slot value of `branch_indices`, the original row positions
+        // and the matching rows of `batch`. Shared by the THEN and ELSE paths.
+        let select_rows = |target: u32| -> Result<(ArrayRef, RecordBatch)> {
+            let row_indices: ArrayRef = Arc::new(UInt32Array::from_iter_values(
+                branch_indices
+                    .iter()
+                    .enumerate()
+                    .filter(|(_, b)| **b == target)
+                    .map(|(row_idx, _)| row_idx as u32),
+            ));
+            let mask: Vec<bool> = branch_indices.iter().map(|&b| b == target).collect();
+            let filter = create_filter(&BooleanArray::from(mask), true);
+            Ok((row_indices, filter_record_batch(batch, &filter)?))
+        };
+
+        let mut result_builder = ResultBuilder::new(return_type, num_rows);
+
+        // Evaluate each THEN expression only for the rows routed to its branch
+        for (branch_idx, (_, then_expr)) in self.when_then_expr.iter().enumerate() {
+            if branch_counts[branch_idx] == 0 {
+                continue;
+            }
+            let (row_indices, then_batch) = select_rows(branch_idx as u32)?;
+            let then_value = then_expr.evaluate(&then_batch)?;
+            result_builder.add_branch_result(&row_indices, then_value)?;
+        }
+
+        // Handle ELSE branch; without an ELSE no result is added for unmatched rows
+        if branch_counts[num_branches] > 0
+            && let Some(e) = &self.else_expr
+        {
+            let (row_indices, else_batch) = select_rows(else_index)?;
+            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
+            let else_value = expr.evaluate(&else_batch)?;
+            result_builder.add_branch_result(&row_indices, else_value)?;
+        }
+
+        result_builder.finish()
+    }
+
     /// See [CaseExpr::expr_or_expr].
     fn expr_or_expr(
         &self,
@@ -1037,6 +1199,24 @@ impl CaseExpr {
         }
     }
 
+    /// Vectorized evaluation for CASE WHEN expressions with multiple branches.
+    /// See [CaseBody::case_when_no_expr_vectorized] for details.
+    fn case_when_no_expr_vectorized(
+        &self,
+        batch: &RecordBatch,
+        projected: &ProjectedCaseBody,
+    ) -> Result<ColumnarValue> {
+        let return_type = self.data_type(&batch.schema())?;
+        if projected.projection.len() < batch.num_columns() {
+            let projected_batch = batch.project(&projected.projection)?;
+            projected
+                .body
+                .case_when_no_expr_vectorized(&projected_batch, &return_type)
+        } else {
+            self.body.case_when_no_expr_vectorized(batch, &return_type)
+        }
+    }
+
     /// This function evaluates the specialized case of:
     ///
     /// CASE WHEN condition THEN column
@@ -1259,6 +1439,10 @@ impl PhysicalExpr for CaseExpr {
                 // arbitrary expressions
                 self.case_when_no_expr(batch, p)
             }
+            EvalMethod::NoExpressionVectorized(p) => {
+                // Vectorized evaluation: all conditions evaluated upfront
+                self.case_when_no_expr_vectorized(batch, p)
+            }
             EvalMethod::InfallibleExprOrNull => {
                 // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
                 self.case_column_or_null(batch)