From 1944c46a3e3ef2540cb6f4a6788221f50d32555b Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Fri, 24 Oct 2025 22:52:11 +0200 Subject: [PATCH 1/8] Project record batches to avoid filtering unused columns --- .../physical-expr/src/expressions/case.rs | 382 +++++++++++++----- 1 file changed, 272 insertions(+), 110 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 0b4c3af1d9c5..b897bfb4c900 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -27,8 +27,10 @@ use arrow::compute::{ use arrow::datatypes::{DataType, Schema, UInt32Type}; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, + exec_err, internal_datafusion_err, internal_err, DataFusionError, HashMap, HashSet, + Result, ScalarValue, }; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::compare_with_eq; @@ -46,13 +48,13 @@ enum EvalMethod { /// [WHEN ...] /// [ELSE result] /// END - NoExpression, + NoExpression(ProjectedCaseBody), /// CASE expression /// WHEN value THEN result /// [WHEN ...] /// [ELSE result] /// END - WithExpression, + WithExpression(ProjectedCaseBody), /// This is a specialization for a specific use case where we can take a fast path /// for expressions that are infallible and can be cheaply computed for the entire /// record batch rather than just for the rows where the predicate is true. 
@@ -68,7 +70,106 @@ enum EvalMethod { /// if there is just one when/then pair and both the `then` and `else` are expressions /// /// CASE WHEN condition THEN expression ELSE expression END - ExpressionOrExpression, + ExpressionOrExpression(ProjectedCaseBody), +} + +/// The body of a CASE expression which consists of an optional base expression, the "when/then" +/// branches and an optional "else" branch. +#[derive(Debug, Hash, PartialEq, Eq)] +struct CaseBody { + /// Optional base expression that can be compared to literal values in the "when" expressions + expr: Option>, + /// One or more when/then expressions + when_then_expr: Vec, + /// Optional "else" expression + else_expr: Option>, +} + +impl CaseBody { + /// Derives a [ProjectedCaseBody] from this [CaseBody]. + fn project(&self) -> Result { + // Determine the set of columns that are used in all the expressions of the case body. + let mut used_column_indices = HashSet::::new(); + let mut collect_column_indices = |expr: &Arc| { + expr.apply(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + used_column_indices.insert(column.index()); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("Closure cannot fail"); + }; + + if let Some(e) = &self.expr { + collect_column_indices(e); + } + self.when_then_expr.iter().for_each(|(w, t)| { + collect_column_indices(w); + collect_column_indices(t); + }); + if let Some(e) = &self.else_expr { + collect_column_indices(e); + } + + // Construct a mapping from the original column index to the projected column index. + let mut column_index_map = HashMap::::new(); + used_column_indices + .iter() + .enumerate() + .for_each(|(projected, original)| { + column_index_map.insert(*original, projected); + }); + + // Construct the projected body by rewriting each expression from the original body + // using the column index mapping. 
+ let project = |expr: &Arc| -> Result> { + Arc::clone(expr) + .transform_down(|e| { + if let Some(column) = e.as_any().downcast_ref::() { + let original = column.index(); + let projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + projected, + )))); + } + } + Ok(Transformed::no(e)) + }) + .map(|t| t.data) + }; + + let projected_body = CaseBody { + expr: self.expr.as_ref().map(project).transpose()?, + when_then_expr: self + .when_then_expr + .iter() + .map(|(e, t)| Ok((project(e)?, project(t)?))) + .collect::>>()?, + else_expr: self.else_expr.as_ref().map(project).transpose()?, + }; + + // Construct the projection vector + let projection = column_index_map + .iter() + .sorted_by_key(|(_, v)| **v) + .map(|(k, _)| *k) + .collect::>(); + + Ok(ProjectedCaseBody { + projection, + body: projected_body, + }) + } +} + +/// A derived case body that can be used to evaluate a case expression after projecting +/// record batches using a projection vector. 
+#[derive(Debug, Hash, PartialEq, Eq)] +struct ProjectedCaseBody { + projection: Vec, + body: CaseBody, } /// The CASE expression is similar to a series of nested if/else and there are two forms that @@ -90,12 +191,8 @@ enum EvalMethod { /// END #[derive(Debug, Hash, PartialEq, Eq)] pub struct CaseExpr { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option>, - /// One or more when/then expressions - when_then_expr: Vec, - /// Optional "else" expression - else_expr: Option>, + /// The case expression body + body: CaseBody, /// Evaluation method to use eval_method: EvalMethod, } @@ -103,13 +200,13 @@ pub struct CaseExpr { impl std::fmt::Display for CaseExpr { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "CASE ")?; - if let Some(e) = &self.expr { + if let Some(e) = &self.body.expr { write!(f, "{e} ")?; } - for (w, t) in &self.when_then_expr { + for (w, t) in &self.body.when_then_expr { write!(f, "WHEN {w} THEN {t} ")?; } - if let Some(e) = &self.else_expr { + if let Some(e) = &self.body.else_expr { write!(f, "ELSE {e} ")?; } write!(f, "END") @@ -556,63 +653,61 @@ impl CaseExpr { }; if when_then_expr.is_empty() { - exec_err!("There must be at least one WHEN clause") - } else { - let eval_method = if expr.is_some() { - EvalMethod::WithExpression - } else if when_then_expr.len() == 1 - && is_cheap_and_infallible(&(when_then_expr[0].1)) - && else_expr.is_none() - { - EvalMethod::InfallibleExprOrNull - } else if when_then_expr.len() == 1 - && when_then_expr[0].1.as_any().is::() - && else_expr.is_some() - && else_expr.as_ref().unwrap().as_any().is::() - { - EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 && else_expr.is_some() { - EvalMethod::ExpressionOrExpression - } else { - EvalMethod::NoExpression - }; - - Ok(Self { - expr, - when_then_expr, - else_expr, - eval_method, - }) + return exec_err!("There must be at least one WHEN clause"); } + + let body = CaseBody { + expr, + 
when_then_expr, + else_expr, + }; + + let eval_method = if body.expr.is_some() { + EvalMethod::WithExpression(body.project()?) + } else if body.when_then_expr.len() == 1 + && is_cheap_and_infallible(&(body.when_then_expr[0].1)) + && body.else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if body.when_then_expr.len() == 1 + && body.when_then_expr[0].1.as_any().is::() + && body.else_expr.is_some() + && body.else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() { + EvalMethod::ExpressionOrExpression(body.project()?) + } else { + EvalMethod::NoExpression(body.project()?) + }; + + Ok(Self { body, eval_method }) } /// Optional base expression that can be compared to literal values in the "when" expressions pub fn expr(&self) -> Option<&Arc> { - self.expr.as_ref() + self.body.expr.as_ref() } /// One or more when/then expressions pub fn when_then_expr(&self) -> &[WhenThen] { - &self.when_then_expr + &self.body.when_then_expr } /// Optional "else" expression pub fn else_expr(&self) -> Option<&Arc> { - self.else_expr.as_ref() + self.body.else_expr.as_ref() } } -impl CaseExpr { - /// This function evaluates the form of CASE that matches an expression to fixed values. - /// - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); +impl CaseBody { + /// See [CaseExpr::case_when_with_expr]. 
+ fn case_when_with_expr( + &self, + batch: &RecordBatch, + return_type: &DataType, + ) -> Result { + let mut result_builder = ResultBuilder::new(return_type, batch.num_rows()); // `remainder_rows` contains the indices of the rows that need to be evaluated let mut remainder_rows: ArrayRef = @@ -641,7 +736,7 @@ impl CaseExpr { // If there is an else expression, use that as the default value for the null rows // Otherwise the default `null` value from the result builder will be used. - if let Some(e) = self.else_expr() { + if let Some(e) = &self.else_expr { let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; if base_all_null { @@ -744,7 +839,7 @@ impl CaseExpr { // If we reached this point, some rows were left unmatched. // Check if those need to be evaluated using the 'else' expression. - if let Some(e) = self.else_expr() { + if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; @@ -754,16 +849,13 @@ impl CaseExpr { result_builder.finish() } - /// This function evaluates the form of CASE where each WHEN expression is a boolean - /// expression. - /// - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); + /// See [CaseExpr::case_when_no_expr]. + fn case_when_no_expr( + &self, + batch: &RecordBatch, + return_type: &DataType, + ) -> Result { + let mut result_builder = ResultBuilder::new(return_type, batch.num_rows()); // `remainder_rows` contains the indices of the rows that need to be evaluated let mut remainder_rows: ArrayRef = @@ -835,7 +927,7 @@ impl CaseExpr { // If we reached this point, some rows were left unmatched. 
// Check if those need to be evaluated using the 'else' expression. - if let Some(e) = self.else_expr() { + if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_value = expr.evaluate(&remainder_batch)?; @@ -845,6 +937,79 @@ impl CaseExpr { result_builder.finish() } + /// See [CaseExpr::expr_or_expr]. + fn expr_or_expr( + &self, + batch: &RecordBatch, + when_value: &BooleanArray, + return_type: &DataType, + ) -> Result { + let then_value = self.when_then_expr[0] + .1 + .evaluate_selection(batch, when_value)? + .into_array(batch.num_rows())?; + + // evaluate else expression on the values not covered by when_value + let remainder = not(when_value)?; + let e = self.else_expr.as_ref().unwrap(); + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); + let else_ = expr + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + + Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) + } +} + +impl CaseExpr { + /// This function evaluates the form of CASE that matches an expression to fixed values. + /// + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + fn case_when_with_expr( + &self, + batch: &RecordBatch, + projected: &ProjectedCaseBody, + ) -> Result { + let return_type = self.data_type(&batch.schema())?; + if projected.projection.len() < batch.num_rows() { + let projected_batch = batch.project(&projected.projection)?; + projected + .body + .case_when_with_expr(&projected_batch, &return_type) + } else { + self.body.case_when_with_expr(batch, &return_type) + } + } + + /// This function evaluates the form of CASE where each WHEN expression is a boolean + /// expression. + /// + /// CASE WHEN condition THEN result + /// [WHEN ...] 
+ /// [ELSE result] + /// END + fn case_when_no_expr( + &self, + batch: &RecordBatch, + projected: &ProjectedCaseBody, + ) -> Result { + let return_type = self.data_type(&batch.schema())?; + if projected.projection.len() < batch.num_rows() { + let projected_batch = batch.project(&projected.projection)?; + projected + .body + .case_when_no_expr(&projected_batch, &return_type) + } else { + self.body.case_when_no_expr(batch, &return_type) + } + } + /// This function evaluates the specialized case of: /// /// CASE WHEN condition THEN column @@ -855,8 +1020,8 @@ impl CaseExpr { /// that are infallible because the expression will be evaluated for all /// rows in the input batch. fn case_column_or_null(&self, batch: &RecordBatch) -> Result { - let when_expr = &self.when_then_expr[0].0; - let then_expr = &self.when_then_expr[0].1; + let when_expr = &self.body.when_then_expr[0].0; + let then_expr = &self.body.when_then_expr[0].1; match when_expr.evaluate(batch)? { // WHEN true --> column @@ -896,7 +1061,7 @@ impl CaseExpr { let return_type = self.data_type(&batch.schema())?; // evaluate when expression - let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = self.body.when_then_expr[0].0.evaluate(batch)?; let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|_| { internal_datafusion_err!("WHEN expression did not return a BooleanArray") @@ -909,10 +1074,10 @@ impl CaseExpr { }; // evaluate then_value - let then_value = self.when_then_expr[0].1.evaluate(batch)?; + let then_value = self.body.when_then_expr[0].1.evaluate(batch)?; let then_value = Scalar::new(then_value.into_array(1)?); - let Some(e) = self.else_expr() else { + let Some(e) = &self.body.else_expr else { return internal_err!("expression did not evaluate to an array"); }; // keep `else_expr`'s data type and return type consistent @@ -921,11 +1086,15 @@ impl CaseExpr { Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) 
} - fn expr_or_expr(&self, batch: &RecordBatch) -> Result { + fn expr_or_expr( + &self, + batch: &RecordBatch, + projected: &ProjectedCaseBody, + ) -> Result { let return_type = self.data_type(&batch.schema())?; // evaluate when condition on batch - let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = self.body.when_then_expr[0].0.evaluate(batch)?; let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { DataFusionError::Context( @@ -939,9 +1108,9 @@ impl CaseExpr { // entirely anyway. let true_count = when_value.true_count(); if true_count == batch.num_rows() { - return self.when_then_expr[0].1.evaluate(batch); + return self.body.when_then_expr[0].1.evaluate(batch); } else if true_count == 0 { - return self.else_expr.as_ref().unwrap().evaluate(batch); + return self.body.else_expr.as_ref().unwrap().evaluate(batch); } // Treat 'NULL' as false value @@ -950,22 +1119,14 @@ impl CaseExpr { _ => Cow::Owned(prep_null_mask_filter(when_value)), }; - let then_value = self.when_then_expr[0] - .1 - .evaluate_selection(batch, &when_value)? - .into_array(batch.num_rows())?; - - // evaluate else expression on the values not covered by when_value - let remainder = not(&when_value)?; - let e = self.else_expr.as_ref().unwrap(); - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - let else_ = expr - .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows())?; - - Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) + if projected.projection.len() < batch.num_rows() { + let projected_batch = batch.project(&projected.projection)?; + projected + .body + .expr_or_expr(&projected_batch, &when_value, &return_type) + } else { + self.body.expr_or_expr(batch, &when_value, &return_type) + } } } @@ -979,15 +1140,15 @@ impl PhysicalExpr for CaseExpr { // since all then results have the same data type, we can choose any one as the // return data type except for the null. let mut data_type = DataType::Null; - for i in 0..self.when_then_expr.len() { - data_type = self.when_then_expr[i].1.data_type(input_schema)?; + for i in 0..self.body.when_then_expr.len() { + data_type = self.body.when_then_expr[i].1.data_type(input_schema)?; if !data_type.equals_datatype(&DataType::Null) { break; } } // if all then results are null, we use data type of else expr instead if possible. if data_type.equals_datatype(&DataType::Null) { - if let Some(e) = &self.else_expr { + if let Some(e) = &self.body.else_expr { data_type = e.data_type(input_schema)?; } } @@ -998,13 +1159,14 @@ impl PhysicalExpr for CaseExpr { fn nullable(&self, input_schema: &Schema) -> Result { // this expression is nullable if any of the input expressions are nullable let then_nullable = self + .body .when_then_expr .iter() .map(|(_, t)| t.nullable(input_schema)) .collect::>>()?; if then_nullable.contains(&true) { Ok(true) - } else if let Some(e) = &self.else_expr { + } else if let Some(e) = &self.body.else_expr { e.nullable(input_schema) } else { // CASE produces NULL if there is no `else` expr @@ -1014,37 +1176,37 @@ impl PhysicalExpr for CaseExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - match self.eval_method { - EvalMethod::WithExpression => { + match &self.eval_method { + EvalMethod::WithExpression(p) => { // this use case evaluates "expr" and then compares the values with the "when" // values - 
self.case_when_with_expr(batch) + self.case_when_with_expr(batch, p) } - EvalMethod::NoExpression => { + EvalMethod::NoExpression(p) => { // The "when" conditions all evaluate to boolean in this use case and can be // arbitrary expressions - self.case_when_no_expr(batch) + self.case_when_no_expr(batch, p) } EvalMethod::InfallibleExprOrNull => { // Specialization for CASE WHEN expr THEN column [ELSE NULL] END self.case_column_or_null(batch) } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), - EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), + EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p), } } fn children(&self) -> Vec<&Arc> { let mut children = vec![]; - if let Some(expr) = &self.expr { + if let Some(expr) = &self.body.expr { children.push(expr) } - self.when_then_expr.iter().for_each(|(cond, value)| { + self.body.when_then_expr.iter().for_each(|(cond, value)| { children.push(cond); children.push(value); }); - if let Some(else_expr) = &self.else_expr { + if let Some(else_expr) = &self.body.else_expr { children.push(else_expr) } children @@ -1059,7 +1221,7 @@ impl PhysicalExpr for CaseExpr { internal_err!("CaseExpr: Wrong number of children") } else { let (expr, when_then_expr, else_expr) = - match (self.expr().is_some(), self.else_expr().is_some()) { + match (self.expr().is_some(), self.body.else_expr.is_some()) { (true, true) => ( Some(&children[0]), &children[1..children.len() - 1], @@ -1085,12 +1247,12 @@ impl PhysicalExpr for CaseExpr { fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE ")?; - if let Some(e) = &self.expr { + if let Some(e) = &self.body.expr { e.fmt_sql(f)?; write!(f, " ")?; } - for (w, t) in &self.when_then_expr { + for (w, t) in &self.body.when_then_expr { write!(f, "WHEN ")?; w.fmt_sql(f)?; write!(f, " THEN ")?; @@ -1098,7 +1260,7 @@ impl PhysicalExpr for CaseExpr { write!(f, " ")?; } - if let Some(e) = &self.else_expr { + if let Some(e) = &self.body.else_expr { 
write!(f, "ELSE ")?; e.fmt_sql(f)?; write!(f, " ")?; @@ -1842,7 +2004,7 @@ mod tests { let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; assert!(matches!( expr.eval_method, - EvalMethod::ExpressionOrExpression + EvalMethod::ExpressionOrExpression(_) )); let result = expr .evaluate(&batch)? From b727249c7f43ad99b8b565ec3211b9250c0ac6e0 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 28 Oct 2025 17:53:21 +0100 Subject: [PATCH 2/8] Integrate suggestions from @rluvaton Co-authored-by: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> --- datafusion/physical-expr/src/expressions/case.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b897bfb4c900..fd754be443c6 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -112,13 +112,11 @@ impl CaseBody { } // Construct a mapping from the original column index to the projected column index. - let mut column_index_map = HashMap::::new(); - used_column_indices + let column_index_map = used_column_indices .iter() .enumerate() - .for_each(|(projected, original)| { - column_index_map.insert(*original, projected); - }); + .map(|(projected, original)| (*original, projected)) + .collect::>(); // Construct the projected body by rewriting each expression from the original body // using the column index mapping. 
@@ -977,7 +975,7 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if projected.projection.len() < batch.num_rows() { + if projected.projection.len() < batch.num_columns() { let projected_batch = batch.project(&projected.projection)?; projected .body @@ -1000,7 +998,7 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if projected.projection.len() < batch.num_rows() { + if projected.projection.len() < batch.num_columns() { let projected_batch = batch.project(&projected.projection)?; projected .body @@ -1119,7 +1117,7 @@ impl CaseExpr { _ => Cow::Owned(prep_null_mask_filter(when_value)), }; - if projected.projection.len() < batch.num_rows() { + if projected.projection.len() < batch.num_columns() { let projected_batch = batch.project(&projected.projection)?; projected .body From 9e9958e6bf8ef3ff08940046ac31a4c0be4ea493 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 28 Oct 2025 18:13:53 +0100 Subject: [PATCH 3/8] Formatting --- datafusion/physical-expr/src/expressions/case.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index fd754be443c6..2153b688969a 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -116,7 +116,7 @@ impl CaseBody { .iter() .enumerate() .map(|(projected, original)| (*original, projected)) - .collect::>(); + .collect::>(); // Construct the projected body by rewriting each expression from the original body // using the column index mapping. 
From 7a29634a35ab7c35296dfc7b3b1538e4a5af3588 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 28 Oct 2025 18:27:19 +0100 Subject: [PATCH 4/8] Add SLTs covering projection in case expressions --- datafusion/sqllogictest/test_files/case.slt | 27 +++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 4eaa87b0b516..9fcc0c469bf3 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -617,3 +617,30 @@ SELECT CASE WHEN a = 0 THEN 'a' WHEN 1 / a = 1 THEN 'b' ELSE 'c' END FROM (VALUE a b c + +# EvalMethod::WithExpression using subset of all selected columns in case expression +query III +SELECT CASE a1 WHEN 1 THEN a1 WHEN 2 THEN a2 WHEN 3 THEN b END, b, c +FROM (SELECT a as a1, a as a2, b, c FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)) t(a, b, c)); +---- +1 10 100 +2 20 200 +30 30 300 + +# EvalMethod::NoExpression using subset of all selected columns in case expression +query III +SELECT CASE WHEN a1 = 1 THEN a2 WHEN a2 = 2 THEN a1 WHEN 3 THEN b END, b, c +FROM (SELECT a as a1, a as a2, b, c FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)) t(a, b, c)); +---- +1 10 100 +2 20 200 +30 30 300 + +# EvalMethod::ExpressionOrExpression using subset of all selected columns in case expression +query III +SELECT CASE WHEN a1 = 1 THEN a2 ELSE b END, b, c +FROM (SELECT a as a1, a as a2, b, c FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)) t(a, b, c)); +---- +1 10 100 +20 20 200 +30 30 300 From 9451303f4a83018a730506ba17ac96e8ad020cba Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 28 Oct 2025 18:29:02 +0100 Subject: [PATCH 5/8] Add SLTs covering 'no projection' in case expressions --- datafusion/sqllogictest/test_files/case.slt | 27 +++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/datafusion/sqllogictest/test_files/case.slt 
b/datafusion/sqllogictest/test_files/case.slt index 9fcc0c469bf3..849777f572bb 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -644,3 +644,30 @@ FROM (SELECT a as a1, a as a2, b, c FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 1 10 100 20 20 200 30 30 300 + +# EvalMethod::WithExpression using all selected columns in case expression +query I +SELECT CASE a1 WHEN 1 THEN a1 WHEN 2 THEN a2 WHEN 3 THEN NULL END +FROM (SELECT a as a1, a as a2, b, c FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)) t(a, b, c)); +---- +1 +2 +NULL + +# EvalMethod::NoExpression using all selected columns in case expression +query I +SELECT CASE WHEN a1 = 1 THEN a2 WHEN a2 = 2 THEN a1 WHEN 3 THEN NULL END +FROM (SELECT a as a1, a as a2, b, c FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)) t(a, b, c)); +---- +1 +2 +NULL + +# EvalMethod::ExpressionOrExpression using all selected columns in case expression +query I +SELECT CASE WHEN a1 = 1 THEN a2 ELSE NULL END +FROM (SELECT a as a1, a as a2 FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)) t(a, b, c)); +---- +1 +NULL +NULL From 7c71790022866eeb81fb2885eeea2fd76e55c91a Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 28 Oct 2025 18:36:42 +0100 Subject: [PATCH 6/8] Add SLTs covering nested project case --- datafusion/sqllogictest/test_files/case.slt | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 849777f572bb..1a4b6a7a2b4a 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -671,3 +671,15 @@ FROM (SELECT a as a1, a as a2 FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 3 1 NULL NULL + +# Nested case with projection +query III +SELECT CASE WHEN a = -1 THEN b WHEN a = -2 THEN -b END, b, c +FROM ( + SELECT b, c, CASE WHEN a1 = 1 THEN -a2 WHEN a1 = 2 THEN -a1 END as a + FROM (SELECT a as a1, a 
as a2, b, c FROM (VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)) t(a, b, c)) +); +---- +10 10 100 +-20 20 200 +NULL 30 300 From 924c99c1c4742ca471630a61d47c2db7b7149044 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Wed, 29 Oct 2025 22:55:39 +0100 Subject: [PATCH 7/8] Add comment clarifying the rationale behind ProjectedCaseBody Co-authored-by: Andrew Lamb --- .../physical-expr/src/expressions/case.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 2153b688969a..eb18184c9f79 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -164,6 +164,23 @@ impl CaseBody { /// A derived case body that can be used to evaluate a case expression after projecting /// record batches using a projection vector. +/// +/// This is used to avoid filtering / copying columns that are not used in the +/// input `RecordBatch` when progressively evaluating a `CASE` expression's +/// remainder batches. +/// +/// For example, if we are evaluating the following case expression that +/// only references columns B and D: +/// +/// ```sql +/// CASE WHEN B > 10 THEN D ELSE NULL END +/// ``` +/// +/// If the input has 4 columns, `[A, B, C, D]` it is wasteful to carry through +/// columns A and C that are not used in the expression. Instead, the projection +/// vector will be `[1, 3]` and the case body +/// will be rewritten to refer to columns `[0, 1]` +/// (i.e. remap references from `B` -> 0`, and `D` -> `1`) #[derive(Debug, Hash, PartialEq, Eq)] struct ProjectedCaseBody { projection: Vec, From 5613fbbfda697dcc48af0a73167522dd1339b31e Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 30 Oct 2025 09:03:41 +0100 Subject: [PATCH 8/8] Add more detail to ProjectedCaseBody documentation. 
--- .../physical-expr/src/expressions/case.rs | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index eb18184c9f79..bd7158db8bb4 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -165,22 +165,30 @@ impl CaseBody { /// A derived case body that can be used to evaluate a case expression after projecting /// record batches using a projection vector. /// -/// This is used to avoid filtering / copying columns that are not used in the +/// This is used to avoid filtering columns that are not used in the /// input `RecordBatch` when progressively evaluating a `CASE` expression's -/// remainder batches. +/// remainder batches. Filtering these columns is wasteful since for a record +/// batch of `n` rows, filtering requires at worst a copy of `n - 1` values +/// per array. If these filtered values will never be accessed, the time spent +/// producing them is better avoided. /// /// For example, if we are evaluating the following case expression that /// only references columns B and D: /// /// ```sql -/// CASE WHEN B > 10 THEN D ELSE NULL END +/// SELECT CASE WHEN B > 10 THEN D ELSE NULL END FROM (VALUES (...)) T(A, B, C, D) /// ``` /// -/// If the input has 4 columns, `[A, B, C, D]` it is wasteful to carry through -/// columns A and C that are not used in the expression. Instead, the projection -/// vector will be `[1, 3]` and the case body -/// will be rewritten to refer to columns `[0, 1]` -/// (i.e. remap references from `B` -> 0`, and `D` -> `1`) +/// Of the 4 input columns `[A, B, C, D]`, the `CASE` expression only accesses `B` and `D`. +/// Filtering `A` and `C` would be unnecessary and wasteful. 
+/// +/// If we only retain columns `B` and `D` using `RecordBatch::project` and the projection vector +/// `[1, 3]`, the indices of these two columns will change to `[0, 1]`. To evaluate the +/// case expression, it will need to be rewritten from `CASE WHEN B@1 > 10 THEN D@3 ELSE NULL END` +/// to `CASE WHEN B@0 > 10 THEN D@1 ELSE NULL END`. +/// +/// The projection vector and the rewritten expression (which only differs from the original in +/// column reference indices) are held in a `ProjectedCaseBody`. #[derive(Debug, Hash, PartialEq, Eq)] struct ProjectedCaseBody { projection: Vec,