-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CaseWhen: coerce the all then and else data type to a common data type #2819
Changes from 3 commits
13d538c
00edadf
d50c671
767f899
de66710
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -25,6 +25,7 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene}; | |||||||||||||||||||||||||||||
use arrow::datatypes::{DataType, Schema}; | ||||||||||||||||||||||||||||||
use arrow::record_batch::RecordBatch; | ||||||||||||||||||||||||||||||
use datafusion_common::{DataFusionError, Result}; | ||||||||||||||||||||||||||||||
use datafusion_expr::binary_rule::comparison_eq_coercion; | ||||||||||||||||||||||||||||||
use datafusion_expr::ColumnarValue; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>); | ||||||||||||||||||||||||||||||
|
@@ -76,7 +77,7 @@ impl CaseExpr { | |||||||||||||||||||||||||||||
/// Create a new CASE WHEN expression | ||||||||||||||||||||||||||||||
pub fn try_new( | ||||||||||||||||||||||||||||||
expr: Option<Arc<dyn PhysicalExpr>>, | ||||||||||||||||||||||||||||||
when_then_expr: &[WhenThen], | ||||||||||||||||||||||||||||||
when_then_expr: Vec<WhenThen>, | ||||||||||||||||||||||||||||||
else_expr: Option<Arc<dyn PhysicalExpr>>, | ||||||||||||||||||||||||||||||
) -> Result<Self> { | ||||||||||||||||||||||||||||||
if when_then_expr.is_empty() { | ||||||||||||||||||||||||||||||
|
@@ -86,7 +87,7 @@ impl CaseExpr { | |||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||
Ok(Self { | ||||||||||||||||||||||||||||||
expr, | ||||||||||||||||||||||||||||||
when_then_expr: when_then_expr.to_vec(), | ||||||||||||||||||||||||||||||
when_then_expr, | ||||||||||||||||||||||||||||||
else_expr, | ||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
@@ -291,12 +292,67 @@ impl PhysicalExpr for CaseExpr { | |||||||||||||||||||||||||||||
/// Create a CASE expression | ||||||||||||||||||||||||||||||
pub fn case( | ||||||||||||||||||||||||||||||
expr: Option<Arc<dyn PhysicalExpr>>, | ||||||||||||||||||||||||||||||
when_thens: &[WhenThen], | ||||||||||||||||||||||||||||||
when_thens: Vec<WhenThen>, | ||||||||||||||||||||||||||||||
else_expr: Option<Arc<dyn PhysicalExpr>>, | ||||||||||||||||||||||||||||||
input_schema: &Schema, | ||||||||||||||||||||||||||||||
) -> Result<Arc<dyn PhysicalExpr>> { | ||||||||||||||||||||||||||||||
// all the result of then and else should be convert to a common data type, | ||||||||||||||||||||||||||||||
// if they can be coercible to a common data type, return error. | ||||||||||||||||||||||||||||||
let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema); | ||||||||||||||||||||||||||||||
let (when_thens, else_expr) = match coerce_type { | ||||||||||||||||||||||||||||||
None => Err(DataFusionError::Plan(format!( | ||||||||||||||||||||||||||||||
"Can't get a common type for then {:?} and else {:?} expression", | ||||||||||||||||||||||||||||||
when_thens, else_expr | ||||||||||||||||||||||||||||||
))), | ||||||||||||||||||||||||||||||
Some(data_type) => { | ||||||||||||||||||||||||||||||
// cast then expr | ||||||||||||||||||||||||||||||
let left = when_thens | ||||||||||||||||||||||||||||||
.into_iter() | ||||||||||||||||||||||||||||||
.map(|(when, then)| { | ||||||||||||||||||||||||||||||
let then = try_cast(then, input_schema, data_type.clone()).unwrap(); | ||||||||||||||||||||||||||||||
(when, then) | ||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||
.collect::<Vec<WhenThen>>(); | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we sure this will always return without error (aka that this will not panic)? Since this function can return an
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch |
||||||||||||||||||||||||||||||
let right = match else_expr { | ||||||||||||||||||||||||||||||
None => None, | ||||||||||||||||||||||||||||||
Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), | ||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||
Ok((left, right)) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
}?; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
fn get_case_common_type( | ||||||||||||||||||||||||||||||
when_thens: &[WhenThen], | ||||||||||||||||||||||||||||||
else_expr: Option<Arc<dyn PhysicalExpr>>, | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I wonder if there is a reason it needs to get a copy, or would a reference work too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reference is not work for this. |
||||||||||||||||||||||||||||||
input_schema: &Schema, | ||||||||||||||||||||||||||||||
) -> Option<DataType> { | ||||||||||||||||||||||||||||||
let thens_type = when_thens | ||||||||||||||||||||||||||||||
.iter() | ||||||||||||||||||||||||||||||
.map(|when_then| { | ||||||||||||||||||||||||||||||
let data_type = &when_then.1.data_type(input_schema).unwrap(); | ||||||||||||||||||||||||||||||
data_type.clone() | ||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||
.collect::<Vec<_>>(); | ||||||||||||||||||||||||||||||
let else_type = match else_expr { | ||||||||||||||||||||||||||||||
None => { | ||||||||||||||||||||||||||||||
// case when then exprs must have one then value | ||||||||||||||||||||||||||||||
thens_type[0].clone() | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), | ||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||
thens_type | ||||||||||||||||||||||||||||||
.iter() | ||||||||||||||||||||||||||||||
.fold(Some(else_type), |left, right_type| match left { | ||||||||||||||||||||||||||||||
None => None, | ||||||||||||||||||||||||||||||
// TODO: now just use the `equal` coercion rule for case when. If find the issue, and | ||||||||||||||||||||||||||||||
liukun4515 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
// refactor again. | ||||||||||||||||||||||||||||||
Some(left_type) => comparison_eq_coercion(&left_type, right_type), | ||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
#[cfg(test)] | ||||||||||||||||||||||||||||||
mod tests { | ||||||||||||||||||||||||||||||
use super::*; | ||||||||||||||||||||||||||||||
|
@@ -323,8 +379,9 @@ mod tests { | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case( | ||||||||||||||||||||||||||||||
Some(col("a", &schema)?), | ||||||||||||||||||||||||||||||
&[(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
vec![(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||
schema.as_ref(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
|
@@ -353,8 +410,9 @@ mod tests { | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case( | ||||||||||||||||||||||||||||||
Some(col("a", &schema)?), | ||||||||||||||||||||||||||||||
&[(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
vec![(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
Some(else_value), | ||||||||||||||||||||||||||||||
schema.as_ref(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
|
@@ -387,8 +445,9 @@ mod tests { | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case( | ||||||||||||||||||||||||||||||
Some(col("a", &schema)?), | ||||||||||||||||||||||||||||||
&[(when1, then1)], | ||||||||||||||||||||||||||||||
vec![(when1, then1)], | ||||||||||||||||||||||||||||||
Some(else_value), | ||||||||||||||||||||||||||||||
schema.as_ref(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
|
@@ -424,7 +483,12 @@ mod tests { | |||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let then2 = lit(ScalarValue::Int32(Some(456))); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case(None, &[(when1, then1), (when2, then2)], None)?; | ||||||||||||||||||||||||||||||
let expr = case( | ||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||
vec![(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||
schema.as_ref(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
.as_any() | ||||||||||||||||||||||||||||||
|
@@ -458,7 +522,7 @@ mod tests { | |||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let x = lit(ScalarValue::Float64(None)); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case(None, &[(when1, then1)], Some(x))?; | ||||||||||||||||||||||||||||||
let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
.as_any() | ||||||||||||||||||||||||||||||
|
@@ -501,7 +565,12 @@ mod tests { | |||||||||||||||||||||||||||||
let then2 = lit(ScalarValue::Int32(Some(456))); | ||||||||||||||||||||||||||||||
let else_value = lit(ScalarValue::Int32(Some(999))); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case(None, &[(when1, then1), (when2, then2)], Some(else_value))?; | ||||||||||||||||||||||||||||||
let expr = case( | ||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||
vec![(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
Some(else_value), | ||||||||||||||||||||||||||||||
schema.as_ref(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
.as_any() | ||||||||||||||||||||||||||||||
|
@@ -531,7 +600,7 @@ mod tests { | |||||||||||||||||||||||||||||
let then = lit(ScalarValue::Float64(Some(123.3))); | ||||||||||||||||||||||||||||||
let else_value = lit(ScalarValue::Int32(Some(999))); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case(None, &[(when, then)], Some(else_value))?; | ||||||||||||||||||||||||||||||
let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
.as_any() | ||||||||||||||||||||||||||||||
|
@@ -560,7 +629,7 @@ mod tests { | |||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let then = col("load4", &schema)?; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case(None, &[(when, then)], None)?; | ||||||||||||||||||||||||||||||
let expr = case(None, vec![(when, then)], None, schema.as_ref())?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
.as_any() | ||||||||||||||||||||||||||||||
|
@@ -585,7 +654,7 @@ mod tests { | |||||||||||||||||||||||||||||
let when = lit(ScalarValue::Float64(Some(1.77))); | ||||||||||||||||||||||||||||||
let then = col("load4", &schema)?; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case(Some(expr), &[(when, then)], None)?; | ||||||||||||||||||||||||||||||
let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?; | ||||||||||||||||||||||||||||||
let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); | ||||||||||||||||||||||||||||||
let result = result | ||||||||||||||||||||||||||||||
.as_any() | ||||||||||||||||||||||||||||||
|
@@ -635,4 +704,67 @@ mod tests { | |||||||||||||||||||||||||||||
RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?; | ||||||||||||||||||||||||||||||
Ok(batch) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
#[test] | ||||||||||||||||||||||||||||||
fn case_test_incompatible() -> Result<()> { | ||||||||||||||||||||||||||||||
// 1 then is int64 | ||||||||||||||||||||||||||||||
// 2 then is boolean | ||||||||||||||||||||||||||||||
let batch = case_test_batch()?; | ||||||||||||||||||||||||||||||
let schema = batch.schema(); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END | ||||||||||||||||||||||||||||||
let when1 = binary( | ||||||||||||||||||||||||||||||
col("a", &schema)?, | ||||||||||||||||||||||||||||||
Operator::Eq, | ||||||||||||||||||||||||||||||
lit(ScalarValue::Utf8(Some("foo".to_string()))), | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice if the the
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will rebase this pr after #2828 merged There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed the code to use |
||||||||||||||||||||||||||||||
&batch.schema(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let then1 = lit(ScalarValue::Int32(Some(123))); | ||||||||||||||||||||||||||||||
let when2 = binary( | ||||||||||||||||||||||||||||||
col("a", &schema)?, | ||||||||||||||||||||||||||||||
Operator::Eq, | ||||||||||||||||||||||||||||||
lit(ScalarValue::Utf8(Some("bar".to_string()))), | ||||||||||||||||||||||||||||||
&batch.schema(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let then2 = lit(ScalarValue::Boolean(Some(true))); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case( | ||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||
vec![(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||
schema.as_ref(), | ||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||
assert!(expr.is_err()); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// then 1 is int32 | ||||||||||||||||||||||||||||||
// then 2 is int64 | ||||||||||||||||||||||||||||||
// else is float | ||||||||||||||||||||||||||||||
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END | ||||||||||||||||||||||||||||||
let when1 = binary( | ||||||||||||||||||||||||||||||
col("a", &schema)?, | ||||||||||||||||||||||||||||||
Operator::Eq, | ||||||||||||||||||||||||||||||
lit(ScalarValue::Utf8(Some("foo".to_string()))), | ||||||||||||||||||||||||||||||
&batch.schema(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let then1 = lit(ScalarValue::Int32(Some(123))); | ||||||||||||||||||||||||||||||
let when2 = binary( | ||||||||||||||||||||||||||||||
col("a", &schema)?, | ||||||||||||||||||||||||||||||
Operator::Eq, | ||||||||||||||||||||||||||||||
lit(ScalarValue::Utf8(Some("bar".to_string()))), | ||||||||||||||||||||||||||||||
&batch.schema(), | ||||||||||||||||||||||||||||||
)?; | ||||||||||||||||||||||||||||||
let then2 = lit(ScalarValue::Int64(Some(456))); | ||||||||||||||||||||||||||||||
let else_expr = lit(ScalarValue::Float64(Some(1.23))); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
let expr = case( | ||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||
vec![(when1, then1), (when2, then2)], | ||||||||||||||||||||||||||||||
Some(else_expr), | ||||||||||||||||||||||||||||||
schema.as_ref(), | ||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||
assert!(expr.is_ok()); | ||||||||||||||||||||||||||||||
let result_type = expr.unwrap().data_type(schema.as_ref())?; | ||||||||||||||||||||||||||||||
assert_eq!(DataType::Float64, result_type); | ||||||||||||||||||||||||||||||
Ok(()) | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍