Skip to content

Commit

Permalink
CaseWhen: coerce the all then and else data type to a common data type (
Browse files Browse the repository at this point in the history
#2819)

* case when: coerce to the same data type for the result data type

* case when: support result type coerced

* change usage of literal
  • Loading branch information
liukun4515 authored Jul 4, 2022
1 parent 57f47ab commit da392f4
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 18 deletions.
157 changes: 145 additions & 12 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::binary_rule::comparison_eq_coercion;
use datafusion_expr::ColumnarValue;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
Expand Down Expand Up @@ -76,7 +77,7 @@ impl CaseExpr {
/// Create a new CASE WHEN expression
pub fn try_new(
expr: Option<Arc<dyn PhysicalExpr>>,
when_then_expr: &[WhenThen],
when_then_expr: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Self> {
if when_then_expr.is_empty() {
Expand All @@ -86,7 +87,7 @@ impl CaseExpr {
} else {
Ok(Self {
expr,
when_then_expr: when_then_expr.to_vec(),
when_then_expr,
else_expr,
})
}
Expand Down Expand Up @@ -291,12 +292,68 @@ impl PhysicalExpr for CaseExpr {
/// Create a CASE expression
pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
when_thens: &[WhenThen],
when_thens: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
// all the result of then and else should be convert to a common data type,
// if they can be coercible to a common data type, return error.
let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema);
let (when_thens, else_expr) = match coerce_type {
None => Err(DataFusionError::Plan(format!(
"Can't get a common type for then {:?} and else {:?} expression",
when_thens, else_expr
))),
Some(data_type) => {
// cast then expr
let left = when_thens
.into_iter()
.map(|(when, then)| {
let then = try_cast(then, input_schema, data_type.clone())?;
Ok((when, then))
})
.collect::<Result<Vec<_>>>()?;
let right = match else_expr {
None => None,
Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
};

Ok((left, right))
}
}?;

Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
}

fn get_case_common_type(
when_thens: &[WhenThen],
else_expr: Option<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Option<DataType> {
let thens_type = when_thens
.iter()
.map(|when_then| {
let data_type = &when_then.1.data_type(input_schema).unwrap();
data_type.clone()
})
.collect::<Vec<_>>();
let else_type = match else_expr {
None => {
// case when then exprs must have one then value
thens_type[0].clone()
}
Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
};
thens_type
.iter()
.fold(Some(else_type), |left, right_type| match left {
None => None,
// TODO: now just use the `equal` coercion rule for case when. If find the issue, and
// refactor again.
Some(left_type) => comparison_eq_coercion(&left_type, right_type),
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -323,8 +380,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1), (when2, then2)],
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -353,8 +411,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1), (when2, then2)],
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -387,8 +446,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1)],
vec![(when1, then1)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -424,7 +484,12 @@ mod tests {
)?;
let then2 = lit(456i32);

let expr = case(None, &[(when1, then1), (when2, then2)], None)?;
let expr = case(
None,
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -453,7 +518,7 @@ mod tests {
)?;
let x = lit(ScalarValue::Float64(None));

let expr = case(None, &[(when1, then1)], Some(x))?;
let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -496,7 +561,12 @@ mod tests {
let then2 = lit(456i32);
let else_value = lit(999i32);

let expr = case(None, &[(when1, then1), (when2, then2)], Some(else_value))?;
let expr = case(
None,
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -526,7 +596,7 @@ mod tests {
let then = lit(123.3f64);
let else_value = lit(999i32);

let expr = case(None, &[(when, then)], Some(else_value))?;
let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -555,7 +625,7 @@ mod tests {
)?;
let then = col("load4", &schema)?;

let expr = case(None, &[(when, then)], None)?;
let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand All @@ -580,7 +650,7 @@ mod tests {
let when = lit(1.77f64);
let then = col("load4", &schema)?;

let expr = case(Some(expr), &[(when, then)], None)?;
let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -630,4 +700,67 @@ mod tests {
RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
Ok(batch)
}

#[test]
fn case_test_incompatible() -> Result<()> {
// 1 then is int64
// 2 then is boolean
let batch = case_test_batch()?;
let schema = batch.schema();

// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit("foo"),
&batch.schema(),
)?;
let then1 = lit(123i32);
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit("bar"),
&batch.schema(),
)?;
let then2 = lit(true);

let expr = case(
None,
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
);
assert!(expr.is_err());

// then 1 is int32
// then 2 is int64
// else is float
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit("foo"),
&batch.schema(),
)?;
let then1 = lit(123i32);
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit("bar"),
&batch.schema(),
)?;
let then2 = lit(456i64);
let else_expr = lit(1.23f64);

let expr = case(
None,
vec![(when1, then1), (when2, then2)],
Some(else_expr),
schema.as_ref(),
);
assert!(expr.is_ok());
let result_type = expr.unwrap().data_type(schema.as_ref())?;
assert_eq!(DataType::Float64, result_type);
Ok(())
}
}
11 changes: 5 additions & 6 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
use crate::expressions::try_cast;
use crate::{
execution_props::ExecutionProps,
expressions::{
self, binary, CaseExpr, Column, DateIntervalExpr, GetIndexedFieldExpr, Literal,
},
expressions::{self, binary, Column, DateIntervalExpr, GetIndexedFieldExpr, Literal},
functions, udf,
var_provider::VarType,
PhysicalExpr,
Expand Down Expand Up @@ -162,11 +160,12 @@ pub fn create_physical_expr(
} else {
None
};
Ok(Arc::new(CaseExpr::try_new(
Ok(expressions::case(
expr,
&when_then_expr,
when_then_expr,
else_expr,
)?))
input_schema,
)?)
}
Expr::Cast { expr, data_type } => expressions::cast(
create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,
Expand Down

0 comments on commit da392f4

Please sign in to comment.