Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CaseWhen: coerce all then and else data types to a common data type #2819

Merged
merged 5 commits into from
Jul 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 145 additions & 12 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::binary_rule::comparison_eq_coercion;
use datafusion_expr::ColumnarValue;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
Expand Down Expand Up @@ -76,7 +77,7 @@ impl CaseExpr {
/// Create a new CASE WHEN expression
pub fn try_new(
expr: Option<Arc<dyn PhysicalExpr>>,
when_then_expr: &[WhenThen],
when_then_expr: Vec<WhenThen>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

else_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Self> {
if when_then_expr.is_empty() {
Expand All @@ -86,7 +87,7 @@ impl CaseExpr {
} else {
Ok(Self {
expr,
when_then_expr: when_then_expr.to_vec(),
when_then_expr,
else_expr,
})
}
Expand Down Expand Up @@ -291,12 +292,68 @@ impl PhysicalExpr for CaseExpr {
/// Create a CASE expression.
///
/// Every `then` expression and the optional `else` expression is cast with
/// `try_cast` to the single common data type computed by
/// `get_case_common_type`, and the resulting expressions are handed to
/// `CaseExpr::try_new`.
///
/// # Errors
///
/// Returns `DataFusionError::Plan` when no common data type can be found for
/// the `then` and `else` expressions. Errors from `try_cast` and
/// `CaseExpr::try_new` are propagated unchanged.
pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
when_thens: &[WhenThen],
when_thens: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
// All `then` results and the `else` result must be converted to a common data type;
// if they cannot be coerced to a common data type, return an error.
// `else_expr` is cloned because `get_case_common_type` takes it by value and it is
// still needed below for the error message and the cast.
let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema);
let (when_thens, else_expr) = match coerce_type {
// No common type: report the offending expressions in a plan error.
None => Err(DataFusionError::Plan(format!(
"Can't get a common type for then {:?} and else {:?} expression",
when_thens, else_expr
))),
Some(data_type) => {
// Cast every `then` expression to the common type; the `when` side is left untouched.
let left = when_thens
.into_iter()
.map(|(when, then)| {
let then = try_cast(then, input_schema, data_type.clone())?;
Ok((when, then))
})
// Collecting into `Result<Vec<_>>` stops at the first failed cast.
.collect::<Result<Vec<_>>>()?;
// Cast the `else` expression, if present, to the same common type.
let right = match else_expr {
None => None,
Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
};

Ok((left, right))
}
}?;

Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
}

/// Computes the data type that every `then` expression and the optional
/// `else` expression can be coerced to.
///
/// The types are combined pairwise with `comparison_eq_coercion`; the result
/// is `None` as soon as one pair has no common type.
///
/// NOTE(review): this function panics instead of returning `None` when a
/// `data_type` call fails (`unwrap`), and when `when_thens` is empty while
/// `else_expr` is `None` (`thens_type[0]`).
fn get_case_common_type(
when_thens: &[WhenThen],
else_expr: Option<Arc<dyn PhysicalExpr>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
else_expr: Option<Arc<dyn PhysicalExpr>>,
else_expr: Option<&PhysicalExpr>,

I wonder if there is a reason it needs to get a copy, or would a reference work too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference is not work for this.

input_schema: &Schema,
) -> Option<DataType> {
// Resolve the data type of each `then` expression (second element of the pair).
let thens_type = when_thens
.iter()
.map(|when_then| {
// NOTE(review): `unwrap` panics if the type cannot be resolved against `input_schema`.
let data_type = &when_then.1.data_type(input_schema).unwrap();
data_type.clone()
})
.collect::<Vec<_>>();
// Seed type for the fold below: the `else` type, or the first `then` type when
// there is no `else` branch.
let else_type = match else_expr {
None => {
// A CASE expression is expected to have at least one WHEN/THEN pair;
// NOTE(review): indexing panics if `when_thens` is empty.
thens_type[0].clone()
}
Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
};
// Fold every `then` type into the accumulated common type; once the accumulator
// is `None` it stays `None`.
thens_type
.iter()
.fold(Some(else_type), |left, right_type| match left {
None => None,
// TODO: for now this reuses the equality-comparison coercion rule for CASE WHEN;
// revisit and refactor if it proves insufficient.
Some(left_type) => comparison_eq_coercion(&left_type, right_type),
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -323,8 +380,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1), (when2, then2)],
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -353,8 +411,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1), (when2, then2)],
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -387,8 +446,9 @@ mod tests {

let expr = case(
Some(col("a", &schema)?),
&[(when1, then1)],
vec![(when1, then1)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
Expand Down Expand Up @@ -424,7 +484,12 @@ mod tests {
)?;
let then2 = lit(456i32);

let expr = case(None, &[(when1, then1), (when2, then2)], None)?;
let expr = case(
None,
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -453,7 +518,7 @@ mod tests {
)?;
let x = lit(ScalarValue::Float64(None));

let expr = case(None, &[(when1, then1)], Some(x))?;
let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -496,7 +561,12 @@ mod tests {
let then2 = lit(456i32);
let else_value = lit(999i32);

let expr = case(None, &[(when1, then1), (when2, then2)], Some(else_value))?;
let expr = case(
None,
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -526,7 +596,7 @@ mod tests {
let then = lit(123.3f64);
let else_value = lit(999i32);

let expr = case(None, &[(when, then)], Some(else_value))?;
let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -555,7 +625,7 @@ mod tests {
)?;
let then = col("load4", &schema)?;

let expr = case(None, &[(when, then)], None)?;
let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand All @@ -580,7 +650,7 @@ mod tests {
let when = lit(1.77f64);
let then = col("load4", &schema)?;

let expr = case(Some(expr), &[(when, then)], None)?;
let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
Expand Down Expand Up @@ -630,4 +700,67 @@ mod tests {
RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
Ok(batch)
}

#[test]
fn case_test_incompatible() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // Builds the predicate `a = <value>` against the batch schema.
    let a_equals = |value: &'static str| -> Result<Arc<dyn PhysicalExpr>> {
        binary(col("a", &schema)?, Operator::Eq, lit(value), &batch.schema())
    };

    // First THEN is int32, second THEN is boolean: no common type exists.
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
    let mismatched = case(
        None,
        vec![
            (a_equals("foo")?, lit(123i32)),
            (a_equals("bar")?, lit(true)),
        ],
        None,
        schema.as_ref(),
    );
    assert!(mismatched.is_err());

    // First THEN is int32, second THEN is int64, ELSE is float64: all branches
    // are expected to be coerced to Float64.
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
    let widened = case(
        None,
        vec![
            (a_equals("foo")?, lit(123i32)),
            (a_equals("bar")?, lit(456i64)),
        ],
        Some(lit(1.23f64)),
        schema.as_ref(),
    );
    assert!(widened.is_ok());
    let result_type = widened.unwrap().data_type(schema.as_ref())?;
    assert_eq!(DataType::Float64, result_type);
    Ok(())
}
}
11 changes: 5 additions & 6 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
use crate::expressions::try_cast;
use crate::{
execution_props::ExecutionProps,
expressions::{
self, binary, CaseExpr, Column, DateIntervalExpr, GetIndexedFieldExpr, Literal,
},
expressions::{self, binary, Column, DateIntervalExpr, GetIndexedFieldExpr, Literal},
functions, udf,
var_provider::VarType,
PhysicalExpr,
Expand Down Expand Up @@ -162,11 +160,12 @@ pub fn create_physical_expr(
} else {
None
};
Ok(Arc::new(CaseExpr::try_new(
Ok(expressions::case(
expr,
&when_then_expr,
when_then_expr,
else_expr,
)?))
input_schema,
)?)
}
Expr::Cast { expr, data_type } => expressions::cast(
create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,
Expand Down