Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions datafusion/physical-expr/benches/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) {
run_benchmarks(c, &make_batch(8192, 100));

benchmark_lookup_table_case_when(c, 8192);
benchmark_searched_case_when_many_branches(c, 8192);
}

fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) {
Expand Down Expand Up @@ -517,5 +518,85 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) {
}
}

/// Benchmark for searched CASE WHEN with many branches (vectorized optimization)
fn benchmark_searched_case_when_many_branches(c: &mut Criterion, batch_size: usize) {
let mut group = c.benchmark_group("searched_case_when_many_branches");

for num_branches in [5, 10, 20] {
// Create a batch with a column containing values 0..batch_size
let array = Arc::new(Int32Array::from_iter_values(0..batch_size as i32));
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new(
"c1",
array.data_type().clone(),
false,
)])),
vec![array],
)
.unwrap();

let c1 = col("c1", &batch.schema()).unwrap();

// CASE WHEN c1 < threshold1 THEN 'a' WHEN c1 < threshold2 THEN 'b' ... ELSE 'z' END
// Thresholds are evenly distributed so all branches get some rows
let step = batch_size as i32 / num_branches as i32;
let when_thens: Vec<_> = (0..num_branches)
.map(|i| {
let threshold = (i as i32 + 1) * step;
(
make_x_cmp_y(&c1, Operator::Lt, threshold),
lit(format!("branch_{}", i)),
)
})
.collect();

let expr = Arc::new(case(None, when_thens, Some(lit("else_branch"))).unwrap());

group.bench_function(
format!(
"{} branches, {} rows, conditions with column refs",
num_branches, batch_size
),
|b| b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())),
);

// Also benchmark with column THEN values (not just literals)
let c2 = Arc::new(Int32Array::from_iter_values(0..batch_size as i32));
let batch_with_c2 = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("c1", c2.data_type().clone(), false),
Field::new("c2", c2.data_type().clone(), false),
])),
vec![Arc::clone(&c2) as ArrayRef, c2],
)
.unwrap();

let c1 = col("c1", &batch_with_c2.schema()).unwrap();
let c2_col = col("c2", &batch_with_c2.schema()).unwrap();

let when_thens: Vec<_> = (0..num_branches)
.map(|i| {
let threshold = (i as i32 + 1) * step;
(
make_x_cmp_y(&c1, Operator::Lt, threshold),
Arc::clone(&c2_col),
)
})
.collect();

let expr = Arc::new(case(None, when_thens, Some(lit(0))).unwrap());

group.bench_function(
format!(
"{} branches, {} rows, column THEN values",
num_branches, batch_size
),
|b| b.iter(|| black_box(expr.evaluate(black_box(&batch_with_c2)).unwrap())),
);
}

group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
188 changes: 188 additions & 0 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ enum EvalMethod {
///
/// See [`LiteralLookupTable`] for more details
WithExprScalarLookupTable(LiteralLookupTable),

/// Vectorized evaluation for CASE WHEN condition expressions with multiple branches.
///
/// This optimization evaluates all conditions upfront on the full batch,
/// builds a branch index array, then filters once per branch.
///
/// **Important**: This changes evaluation semantics - conditions are evaluated
/// even for rows where an earlier condition matched. This is safe when conditions
/// are simple comparisons but may cause issues if conditions can error
/// (e.g., division by zero in a condition).
NoExpressionVectorized(ProjectedCaseBody),
}

/// Implementing hash so we can use `derive` on [`EvalMethod`].
Expand Down Expand Up @@ -661,6 +672,11 @@ impl CaseExpr {
EvalMethod::ScalarOrScalar
} else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
EvalMethod::ExpressionOrExpression(body.project()?)
} else if body.when_then_expr.len() >= 3 {
// Use vectorized evaluation for 3+ branches
// This evaluates all conditions upfront for better cache locality
// Note: This changes short-circuit semantics for CONDITIONS
EvalMethod::NoExpressionVectorized(body.project()?)
} else {
EvalMethod::NoExpression(body.project()?)
},
Expand Down Expand Up @@ -940,6 +956,156 @@ impl CaseBody {
result_builder.finish()
}

/// Vectorized evaluation for CASE WHEN expressions with multiple branches.
///
/// Instead of evaluating conditions sequentially on progressively shrinking batches,
/// this evaluates all conditions upfront on the full batch, builds a branch index
/// array indicating which branch matched each row, then filters once per branch.
///
/// This provides better performance for many branches due to:
/// - Better cache locality (conditions evaluated on same data)
/// - Simpler filter predicates (integer equality vs boolean expressions)
/// - No progressive batch shrinking overhead
///
/// **Important**: This changes short-circuit semantics for CONDITIONS (not THEN expressions).
/// All conditions are evaluated even for rows where an earlier condition matched.
/// This is safe for simple comparisons but may cause issues if conditions can error.
fn case_when_no_expr_vectorized(
&self,
batch: &RecordBatch,
return_type: &DataType,
) -> Result<ColumnarValue> {
let num_rows = batch.num_rows();
let num_branches = self.when_then_expr.len();

// else_index is used for rows that don't match any condition
let else_index = num_branches as u32;

// branch_indices[row] = index of first matching branch, or else_index if none match
let mut branch_indices = vec![else_index; num_rows];

// Evaluate all conditions and build branch index array
// We iterate in reverse so earlier conditions overwrite later ones
// (implementing first-match semantics)
for branch_idx in (0..num_branches).rev() {
let when_predicate = &self.when_then_expr[branch_idx].0;
let when_value = when_predicate.evaluate(batch)?.into_array(num_rows)?;
let when_value = as_boolean_array(&when_value).map_err(|_| {
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
})?;

// For each true value in the condition, set branch_indices to this branch
let branch_idx_u32 = branch_idx as u32;
match when_value.nulls() {
Some(nulls) => {
// Handle nulls - treat null as false
for (row_idx, (is_true, is_valid)) in
when_value.values().iter().zip(nulls.iter()).enumerate()
{
if is_valid && is_true {
branch_indices[row_idx] = branch_idx_u32;
}
}
}
None => {
// No nulls - just check boolean values
for (row_idx, is_true) in when_value.values().iter().enumerate() {
if is_true {
branch_indices[row_idx] = branch_idx_u32;
}
}
}
}
}

// Count rows per branch to identify which branches have matches
let mut branch_counts = vec![0usize; num_branches + 1]; // +1 for else
for &branch_idx in &branch_indices {
branch_counts[branch_idx as usize] += 1;
}

// Check if all rows went to a single branch
for (branch_idx, &count) in branch_counts.iter().enumerate() {
if count == num_rows {
// All rows matched this single branch
if branch_idx == num_branches {
// All rows go to ELSE
if let Some(e) = &self.else_expr {
let expr = try_cast(
Arc::clone(e),
&batch.schema(),
return_type.clone(),
)?;
return expr.evaluate(batch);
} else {
// No else expr, return nulls
return Ok(ColumnarValue::Array(new_null_array(
return_type,
num_rows,
)));
}
} else {
// All rows matched branch_idx
return self.when_then_expr[branch_idx].1.evaluate(batch);
}
}
}

let mut result_builder = ResultBuilder::new(return_type, num_rows);

// Process each branch that has matching rows
for (branch_idx, &count) in branch_counts.iter().enumerate().take(num_branches) {
if count == 0 {
continue;
}

// Build row indices for this branch
let row_indices: ArrayRef = Arc::new(UInt32Array::from_iter_values(
branch_indices
.iter()
.enumerate()
.filter(|(_, b)| **b == branch_idx as u32)
.map(|(row_idx, _)| row_idx as u32),
));

// Create filter predicate for this branch
let filter_values: BooleanArray = branch_indices
.iter()
.map(|&b| b == branch_idx as u32)
.collect();
let then_filter = create_filter(&filter_values, true);
let then_batch = filter_record_batch(batch, &then_filter)?;

// Evaluate THEN expression only for matching rows
let then_value = self.when_then_expr[branch_idx].1.evaluate(&then_batch)?;
result_builder.add_branch_result(&row_indices, then_value)?;
}

// Handle ELSE branch
if branch_counts[num_branches] > 0
&& let Some(e) = &self.else_expr
{
let row_indices: ArrayRef = Arc::new(UInt32Array::from_iter_values(
branch_indices
.iter()
.enumerate()
.filter(|(_, b)| **b == else_index)
.map(|(row_idx, _)| row_idx as u32),
));

let filter_values: BooleanArray =
branch_indices.iter().map(|&b| b == else_index).collect();
let else_filter = create_filter(&filter_values, true);
let else_batch = filter_record_batch(batch, &else_filter)?;

let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
let else_value = expr.evaluate(&else_batch)?;
result_builder.add_branch_result(&row_indices, else_value)?;
}

result_builder.finish()
}

/// See [CaseExpr::expr_or_expr].
fn expr_or_expr(
&self,
Expand Down Expand Up @@ -1037,6 +1203,24 @@ impl CaseExpr {
}
}

/// Vectorized evaluation for CASE WHEN expressions with multiple branches.
/// See [CaseBody::case_when_no_expr_vectorized] for details.
fn case_when_no_expr_vectorized(
&self,
batch: &RecordBatch,
projected: &ProjectedCaseBody,
) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;
if projected.projection.len() < batch.num_columns() {
let projected_batch = batch.project(&projected.projection)?;
projected
.body
.case_when_no_expr_vectorized(&projected_batch, &return_type)
} else {
self.body.case_when_no_expr_vectorized(batch, &return_type)
}
}

/// This function evaluates the specialized case of:
///
/// CASE WHEN condition THEN column
Expand Down Expand Up @@ -1259,6 +1443,10 @@ impl PhysicalExpr for CaseExpr {
// arbitrary expressions
self.case_when_no_expr(batch, p)
}
EvalMethod::NoExpressionVectorized(p) => {
// Vectorized evaluation: all conditions evaluated upfront
self.case_when_no_expr_vectorized(batch, p)
}
EvalMethod::InfallibleExprOrNull => {
// Specialization for CASE WHEN expr THEN column [ELSE NULL] END
self.case_column_or_null(batch)
Expand Down
Loading