CaseWhen: coerce the all then and else data type to a common data type #2819
liukun4515 merged 5 commits into apache:master from liukun4515:fix_case_when_datatype_#2818 on Jul 4, 2022
Commits
13d538c case when: coerce to the same data type for the result data type (liukun4515)
00edadf case when: support result type coerced (liukun4515)
d50c671 Merge remote-tracking branch 'upstream/master' into fix_case_when_dat… (liukun4515)
767f899 Merge remote-tracking branch 'upstream/master' into fix_case_when_dat… (liukun4515)
de66710 change usage of literal (liukun4515)
@@ -25,6 +25,7 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
 use arrow::datatypes::{DataType, Schema};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::binary_rule::comparison_eq_coercion;
 use datafusion_expr::ColumnarValue;
 
 type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -76,7 +77,7 @@ impl CaseExpr {
     /// Create a new CASE WHEN expression
     pub fn try_new(
         expr: Option<Arc<dyn PhysicalExpr>>,
-        when_then_expr: &[WhenThen],
+        when_then_expr: Vec<WhenThen>,
         else_expr: Option<Arc<dyn PhysicalExpr>>,
     ) -> Result<Self> {
         if when_then_expr.is_empty() {
@@ -86,7 +87,7 @@ impl CaseExpr {
         } else {
             Ok(Self {
                 expr,
-                when_then_expr: when_then_expr.to_vec(),
+                when_then_expr,
                 else_expr,
             })
         }
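The two hunks above change `try_new` from borrowing a slice to taking the vector by value, which removes the `to_vec()` copy. As a rough illustration of what that copy costs, here is a minimal standalone sketch; `Pair`, `Holder`, `from_slice`, and `from_vec` are hypothetical stand-ins and are not part of this PR or of DataFusion.

```rust
use std::sync::Arc;

// Hypothetical stand-in for the PR's `WhenThen` pair of expressions.
type Pair = (Arc<str>, Arc<str>);

struct Holder {
    pairs: Vec<Pair>,
}

// Borrowing constructor: has to copy the slice, cloning every `Arc` handle.
fn from_slice(pairs: &[Pair]) -> Holder {
    Holder { pairs: pairs.to_vec() }
}

// Owning constructor: the vector is moved in, so nothing is cloned.
fn from_vec(pairs: Vec<Pair>) -> Holder {
    Holder { pairs }
}

fn main() {
    let when: Arc<str> = Arc::from("a = 'foo'");
    let then: Arc<str> = Arc::from("123");
    let pairs = vec![(when.clone(), then.clone())];

    let copied = from_slice(&pairs);
    // Handles: the local `when`, the one in `pairs`, and the one in `copied`.
    assert_eq!(Arc::strong_count(&when), 3);
    drop(copied);

    let moved = from_vec(pairs);
    // Handles: the local `when` and the one now owned by `moved`.
    assert_eq!(Arc::strong_count(&when), 2);
    assert_eq!(moved.pairs.len(), 1);
}
```

Cloning an `Arc` only bumps a reference count, so the saving is small; the by-value signature mainly lets callers hand over a vector they no longer need.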
@@ -291,12 +292,68 @@ impl PhysicalExpr for CaseExpr {
 /// Create a CASE expression
 pub fn case(
     expr: Option<Arc<dyn PhysicalExpr>>,
-    when_thens: &[WhenThen],
+    when_thens: Vec<WhenThen>,
     else_expr: Option<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
+    // The results of all then and else expressions are converted to a common data type;
+    // if they cannot be coerced to a common data type, return an error.
+    let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema);
+    let (when_thens, else_expr) = match coerce_type {
+        None => Err(DataFusionError::Plan(format!(
+            "Can't get a common type for then {:?} and else {:?} expression",
+            when_thens, else_expr
+        ))),
+        Some(data_type) => {
+            // cast then expr
+            let left = when_thens
+                .into_iter()
+                .map(|(when, then)| {
+                    let then = try_cast(then, input_schema, data_type.clone())?;
+                    Ok((when, then))
+                })
+                .collect::<Result<Vec<_>>>()?;
+            let right = match else_expr {
+                None => None,
+                Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
+            };
+
+            Ok((left, right))
+        }
+    }?;
+
     Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
 }
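The `case` function above casts every then expression and gathers the results with `collect::<Result<Vec<_>>>()`, so the first failed cast aborts the whole conversion. The following is a minimal standalone sketch of that pattern using only the standard library; `Value`, `cast_to_float`, and `cast_all` are hypothetical names for illustration and do not mirror DataFusion's `try_cast`.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Value {
    Int(i64),
    Float(f64),
    Bool(bool),
}

// Hypothetical cast: numbers become f64, anything else is an error.
fn cast_to_float(v: Value) -> Result<f64, String> {
    match v {
        Value::Int(i) => Ok(i as f64),
        Value::Float(f) => Ok(f),
        Value::Bool(b) => Err(format!("cannot cast Bool({}) to float", b)),
    }
}

// Cast the second element of every pair. Collecting into `Result<Vec<_>, _>`
// stops at the first `Err` and returns it instead of a partial vector.
fn cast_all(pairs: Vec<(&'static str, Value)>) -> Result<Vec<(&'static str, f64)>, String> {
    pairs
        .into_iter()
        .map(|(when, then)| {
            let then = cast_to_float(then)?;
            Ok((when, then))
        })
        .collect::<Result<Vec<_>, String>>()
}

fn main() {
    let ok = cast_all(vec![("a = 'foo'", Value::Int(123)), ("a = 'bar'", Value::Float(1.5))]);
    assert_eq!(ok, Ok(vec![("a = 'foo'", 123.0), ("a = 'bar'", 1.5)]));

    let err = cast_all(vec![("a = 'foo'", Value::Int(123)), ("a = 'bar'", Value::Bool(true))]);
    assert!(err.is_err());
}
```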
 
+fn get_case_common_type(
+    when_thens: &[WhenThen],
+    else_expr: Option<Arc<dyn PhysicalExpr>>,

Review comment (with a suggested change): I wonder if there is a reason it needs to get a copy, or would a reference work too?
Reply: A reference does not work for this.

+    input_schema: &Schema,
+) -> Option<DataType> {
+    let thens_type = when_thens
+        .iter()
+        .map(|when_then| {
+            let data_type = &when_then.1.data_type(input_schema).unwrap();
+            data_type.clone()
+        })
+        .collect::<Vec<_>>();
+    let else_type = match else_expr {
+        None => {
+            // case when then exprs must have one then value
+            thens_type[0].clone()
+        }
+        Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
+    };
+    thens_type
+        .iter()
+        .fold(Some(else_type), |left, right_type| match left {
+            None => None,
+            // TODO: for now, just use the `equal` coercion rule for case when. If an issue is found,

liukun4515 marked this conversation as resolved.

+            // refactor again.
+            Some(left_type) => comparison_eq_coercion(&left_type, right_type),
+        })
+}
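`get_case_common_type` above seeds a fold with the else type (or the first then type when there is no else) and narrows it pairwise across all then types, giving up as soon as one pair has no common type. The sketch below shows the same fold on a toy type system. `ToyType`, `numeric_rank`, `toy_coercion`, and `common_type` are hypothetical, and the widening rule is an assumption made for illustration; it does not reproduce the rules of `comparison_eq_coercion`.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum ToyType {
    Boolean,
    Int32,
    Int64,
    Float64,
}

// Assumed ordering for the toy rule: wider numeric types get a higher rank.
fn numeric_rank(t: ToyType) -> Option<u8> {
    match t {
        ToyType::Int32 => Some(0),
        ToyType::Int64 => Some(1),
        ToyType::Float64 => Some(2),
        ToyType::Boolean => None,
    }
}

// Hypothetical pairwise rule: equal types match, numeric types widen,
// every other combination has no common type.
fn toy_coercion(left: ToyType, right: ToyType) -> Option<ToyType> {
    if left == right {
        return Some(left);
    }
    match (numeric_rank(left), numeric_rank(right)) {
        (Some(l), Some(r)) => Some(if l >= r { left } else { right }),
        _ => None,
    }
}

// Fold all then types into one common type, seeded with the else type
// or, when there is no else branch, with the first then type.
fn common_type(thens: &[ToyType], else_type: Option<ToyType>) -> Option<ToyType> {
    let seed = else_type.or_else(|| thens.first().copied())?;
    thens
        .iter()
        .fold(Some(seed), |acc, &right| acc.and_then(|left| toy_coercion(left, right)))
}

fn main() {
    // int32 and int64 then branches with a float64 else branch.
    let widened = common_type(&[ToyType::Int32, ToyType::Int64], Some(ToyType::Float64));
    assert_eq!(widened, Some(ToyType::Float64));

    // int32 and boolean then branches have no common type under the toy rule.
    let incompatible = common_type(&[ToyType::Int32, ToyType::Boolean], None);
    assert_eq!(incompatible, None);
}
```

Unlike the PR code, which indexes `thens_type[0]` and relies on there being at least one then branch, this sketch returns `None` for an empty list.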
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -323,8 +380,9 @@ mod tests {
 
         let expr = case(
             Some(col("a", &schema)?),
-            &[(when1, then1), (when2, then2)],
+            vec![(when1, then1), (when2, then2)],
             None,
+            schema.as_ref(),
         )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
@@ -353,8 +411,9 @@ mod tests {
 
         let expr = case(
             Some(col("a", &schema)?),
-            &[(when1, then1), (when2, then2)],
+            vec![(when1, then1), (when2, then2)],
             Some(else_value),
+            schema.as_ref(),
         )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
@@ -387,8 +446,9 @@ mod tests {
 
         let expr = case(
             Some(col("a", &schema)?),
-            &[(when1, then1)],
+            vec![(when1, then1)],
             Some(else_value),
+            schema.as_ref(),
         )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
@@ -424,7 +484,12 @@ mod tests {
         )?;
         let then2 = lit(456i32);
 
-        let expr = case(None, &[(when1, then1), (when2, then2)], None)?;
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            None,
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -453,7 +518,7 @@ mod tests {
         )?;
         let x = lit(ScalarValue::Float64(None));
 
-        let expr = case(None, &[(when1, then1)], Some(x))?;
+        let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -496,7 +561,12 @@ mod tests {
         let then2 = lit(456i32);
         let else_value = lit(999i32);
 
-        let expr = case(None, &[(when1, then1), (when2, then2)], Some(else_value))?;
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            Some(else_value),
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -526,7 +596,7 @@ mod tests {
         let then = lit(123.3f64);
         let else_value = lit(999i32);
 
-        let expr = case(None, &[(when, then)], Some(else_value))?;
+        let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -555,7 +625,7 @@ mod tests {
         )?;
         let then = col("load4", &schema)?;
 
-        let expr = case(None, &[(when, then)], None)?;
+        let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -580,7 +650,7 @@ mod tests {
         let when = lit(1.77f64);
         let then = col("load4", &schema)?;
 
-        let expr = case(Some(expr), &[(when, then)], None)?;
+        let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -630,4 +700,67 @@ mod tests {
             RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
         Ok(batch)
     }
+
+    #[test]
+    fn case_test_incompatible() -> Result<()> {
+        // then 1 is int32
+        // then 2 is boolean
+        let batch = case_test_batch()?;
+        let schema = batch.schema();
+
+        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
+        let when1 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("foo"),
+            &batch.schema(),
+        )?;
+        let then1 = lit(123i32);
+        let when2 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("bar"),
+            &batch.schema(),
+        )?;
+        let then2 = lit(true);
+
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            None,
+            schema.as_ref(),
+        );
+        assert!(expr.is_err());
+
+        // then 1 is int32
+        // then 2 is int64
+        // else is float
+        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
+        let when1 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("foo"),
+            &batch.schema(),
+        )?;
+        let then1 = lit(123i32);
+        let when2 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("bar"),
+            &batch.schema(),
+        )?;
+        let then2 = lit(456i64);
+        let else_expr = lit(1.23f64);
+
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            Some(else_expr),
+            schema.as_ref(),
+        );
+        assert!(expr.is_ok());
+        let result_type = expr.unwrap().data_type(schema.as_ref())?;
+        assert_eq!(DataType::Float64, result_type);
+        Ok(())
+    }
 }
👍