Skip to content

Commit

Permalink
case when: support result type coerced
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Jul 1, 2022
1 parent 13d538c commit 50fd76b
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 16 deletions.
154 changes: 142 additions & 12 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::binary_rule::comparison_eq_coercion;
use datafusion_expr::ColumnarValue;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
Expand Down Expand Up @@ -76,7 +77,7 @@ impl CaseExpr {
/// Create a new CASE WHEN expression
pub fn try_new(
expr: Option<Arc<dyn PhysicalExpr>>,
when_then_expr: &[WhenThen],
when_then_expr: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Self> {
if when_then_expr.is_empty() {
Expand All @@ -86,7 +87,7 @@ impl CaseExpr {
} else {
Ok(Self {
expr,
when_then_expr: when_then_expr.to_vec(),
when_then_expr,
else_expr,
})
}
Expand Down Expand Up @@ -291,14 +292,67 @@ impl PhysicalExpr for CaseExpr {
/// Create a CASE expression
pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
when_thens: &[WhenThen],
when_thens: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
// all the result of then and else should be convert to a common data type,
// if they can be coercible to a common data type, return error.
let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema);
let (when_thens, else_expr) = match coerce_type {
None => Err(DataFusionError::Plan(format!(
"Can't get a common type for then {:?} and else {:?} expression",
when_thens, else_expr
))),
Some(data_type) => {
// cast then expr
let left = when_thens
.into_iter()
.map(|(when, then)| {
let then = try_cast(then, input_schema, data_type.clone()).unwrap();
(when, then)
})
.collect::<Vec<WhenThen>>();
let right = match else_expr {
None => None,
Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
};
Ok((left, right))
}
}?;

Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
}

fn get_case_common_type(
when_thens: &[WhenThen],
else_expr: Option<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Option<DataType> {
let thens_type = when_thens
.iter()
.map(|when_then| {
let data_type = &when_then.1.data_type(input_schema).unwrap();
data_type.clone()
})
.collect::<Vec<_>>();
let else_type = match else_expr {
None => {
// case when then exprs must have one then value
thens_type[0].clone()
}
Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
};
thens_type
.iter()
.fold(Some(else_type), |left, right_type| match left {
None => None,
// TODO: now just use the `equal` coercion rule for case when. If find the issue, and
// refactor again.
Some(left_type) => comparison_eq_coercion(&left_type, right_type),
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -325,8 +379,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1), (when2, then2)],
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -355,8 +410,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1), (when2, then2)],
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -389,8 +445,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1)],
vec![(when1, then1)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -426,7 +483,12 @@ mod tests {
)?;
let then2 = lit(ScalarValue::Int32(Some(456)));

let expr = case(None, &[(when1, then1), (when2, then2)], None)?;
let expr = case(
None,
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -460,7 +522,7 @@ mod tests {
)?;
let x = lit(ScalarValue::Float64(None));

let expr = case(None, &[(when1, then1)], Some(x))?;
let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -503,7 +565,12 @@ mod tests {
let then2 = lit(ScalarValue::Int32(Some(456)));
let else_value = lit(ScalarValue::Int32(Some(999)));

let expr = case(None, &[(when1, then1), (when2, then2)], Some(else_value))?;
let expr = case(
None,
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -533,7 +600,7 @@ mod tests {
let then = lit(ScalarValue::Float64(Some(123.3)));
let else_value = lit(ScalarValue::Int32(Some(999)));

let expr = case(None, &[(when, then)], Some(else_value))?;
let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -562,7 +629,7 @@ mod tests {
)?;
let then = col("load4", &schema)?;

let expr = case(None, &[(when, then)], None)?;
let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand All @@ -587,7 +654,7 @@ mod tests {
let when = lit(ScalarValue::Float64(Some(1.77)));
let then = col("load4", &schema)?;

let expr = case(Some(expr), &[(when, then)], None)?;
let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -637,4 +704,67 @@ mod tests {
RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
Ok(batch)
}

#[test]
fn case_test_incompatible() -> Result<()> {
// 1 then is int64
// 2 then is boolean
let batch = case_test_batch()?;
let schema = batch.schema();

// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("foo".to_string()))),
&batch.schema(),
)?;
let then1 = lit(ScalarValue::Int32(Some(123)));
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("bar".to_string()))),
&batch.schema(),
)?;
let then2 = lit(ScalarValue::Boolean(Some(true)));

let expr = case(
None,
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
);
assert!(expr.is_err());

// then 1 is int32
// then 2 is int64
// else is float
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("foo".to_string()))),
&batch.schema(),
)?;
let then1 = lit(ScalarValue::Int32(Some(123)));
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("bar".to_string()))),
&batch.schema(),
)?;
let then2 = lit(ScalarValue::Int64(Some(456)));
let else_expr = lit(ScalarValue::Float64(Some(1.23)));

let expr = case(
None,
vec![(when1, then1), (when2, then2)],
Some(else_expr),
schema.as_ref(),
);
assert!(expr.is_ok());
let result_type = expr.unwrap().data_type(schema.as_ref())?;
assert_eq!(DataType::Float64, result_type);
Ok(())
}
}
11 changes: 7 additions & 4 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
use crate::expressions::try_cast;
use crate::{
execution_props::ExecutionProps,
expressions::{
self, binary, CaseExpr, Column, DateIntervalExpr, GetIndexedFieldExpr, Literal,
},
expressions::{self, binary, Column, DateIntervalExpr, GetIndexedFieldExpr, Literal},
functions, udf,
var_provider::VarType,
PhysicalExpr,
Expand Down Expand Up @@ -162,7 +160,12 @@ pub fn create_physical_expr(
} else {
None
};
Ok(expressions::case(expr, &when_then_expr, else_expr)?)
Ok(expressions::case(
expr,
when_then_expr,
else_expr,
input_schema,
)?)
}
Expr::Cast { expr, data_type } => expressions::cast(
create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,
Expand Down

0 comments on commit 50fd76b

Please sign in to comment.