Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 1 addition & 85 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filt
use arrow::datatypes::{DataType, Schema};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
exec_err, internal_datafusion_err, internal_err, Result, ScalarValue,
};
use datafusion_expr::ColumnarValue;

Expand Down Expand Up @@ -62,11 +62,6 @@ enum EvalMethod {
/// are literal values
/// CASE WHEN condition THEN literal ELSE literal END
ScalarOrScalar,
/// This is a specialization for a specific use case where we can take a fast path
/// if there is just one when/then pair and both the `then` and `else` are expressions
///
/// CASE WHEN condition THEN expression ELSE expression END
ExpressionOrExpression,
}

/// The CASE expression is similar to a series of nested if/else and there are two forms that
Expand Down Expand Up @@ -156,8 +151,6 @@ impl CaseExpr {
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
{
EvalMethod::ScalarOrScalar
} else if when_then_expr.len() == 1 && else_expr.is_some() {
EvalMethod::ExpressionOrExpression
} else {
EvalMethod::NoExpression
};
Expand Down Expand Up @@ -407,43 +400,6 @@ impl CaseExpr {
let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
}

fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;

// evalute when condition on batch
let when_value = self.when_then_expr[0].0.evaluate(batch)?;
let when_value = when_value.into_array(batch.num_rows())?;
let when_value = as_boolean_array(&when_value).map_err(|e| {
DataFusionError::Context(
"WHEN expression did not return a BooleanArray".to_string(),
Box::new(e),
)
})?;

// Treat 'NULL' as false value
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
_ => Cow::Owned(prep_null_mask_filter(when_value)),
};

let then_value = self.when_then_expr[0]
.1
.evaluate_selection(batch, &when_value)?
.into_array(batch.num_rows())?;

// evaluate else expression on the values not covered by when_value
let remainder = not(&when_value)?;
let e = self.else_expr.as_ref().unwrap();
// keep `else_expr`'s data type and return type consistent
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
.unwrap_or_else(|_| Arc::clone(e));
let else_ = expr
.evaluate_selection(batch, &remainder)?
.into_array(batch.num_rows())?;

Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
}
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -507,7 +463,6 @@ impl PhysicalExpr for CaseExpr {
self.case_column_or_null(batch)
}
EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
}
}

Expand Down Expand Up @@ -1296,45 +1251,6 @@ mod tests {
Ok(())
}

#[test]
fn test_expr_or_expr_specialization() -> Result<()> {
let batch = case_test_batch1()?;
let schema = batch.schema();
let when = binary(
col("a", &schema)?,
Operator::LtEq,
lit(2i32),
&batch.schema(),
)?;
let then = binary(
col("a", &schema)?,
Operator::Plus,
lit(1i32),
&batch.schema(),
)?;
let else_expr = binary(
col("a", &schema)?,
Operator::Minus,
lit(1i32),
&batch.schema(),
)?;
let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
assert!(matches!(
expr.eval_method,
EvalMethod::ExpressionOrExpression
));
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_int32_array(&result).expect("failed to downcast to Int32Array");

let expected = &Int32Array::from(vec![Some(2), Some(1), None, Some(4)]);

assert_eq!(expected, result);
Ok(())
}

fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
Arc::new(Column::new(name, index))
}
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,18 @@ FROM t;
----
[{foo: blarg}]

query II
SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) t(v)
----
0 42
1 10
2 5

query II
SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v)
----
1 1
2 1

statement ok
drop table t