diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 75566208e307..da2f396e8ebb 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1956,7 +1956,7 @@ mod tests { use fmt::Debug; use std::collections::HashMap; use std::convert::TryFrom; - use std::ops::Not; + use std::ops::{BitAnd, Not}; use std::{any::Any, fmt}; fn make_session_state() -> SessionState { @@ -2140,18 +2140,17 @@ mod tests { async fn errors() -> Result<()> { let bool_expr = col("c1").eq(col("c1")); let cases = vec![ - // utf8 AND utf8 - col("c1").and(col("c1")), + // utf8 = utf8 + col("c1").eq(col("c1")), // u8 AND u8 - col("c3").and(col("c3")), - // utf8 = bool - col("c1").eq(bool_expr.clone()), - // u32 AND bool - col("c2").and(bool_expr), + col("c3").bitand(col("c3")), + // utf8 = u8 + col("c1").eq(col("c3")), + // bool AND bool + bool_expr.clone().and(bool_expr), ]; for case in cases { - let logical_plan = test_csv_scan().await?.project(vec![case.clone()]); - assert!(logical_plan.is_ok()); + test_csv_scan().await?.project(vec![case.clone()]).unwrap(); } Ok(()) } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 7c9179b2f38d..30ce15bd185c 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -25,9 +25,144 @@ use arrow::datatypes::{ use datafusion_common::DataFusionError; use datafusion_common::Result; -use crate::type_coercion::{is_datetime, is_decimal, is_interval, is_numeric}; +use crate::type_coercion::{is_decimal, is_numeric}; use crate::Operator; +/// The type signature of an instantiation of binary expression +struct Signature { + /// The type to coerce the left argument to + lhs: DataType, + /// The type to coerce the right argument to + rhs: DataType, + /// The return type of the expression + ret: DataType, +} + +impl Signature { + /// A signature where the inputs are coerced to the same type as the output + fn coerced(t: DataType) -> Self { + Self { + lhs: t.clone(), + rhs: t.clone(), + ret: t, + } + } + + /// A signature where the inputs are coerced to the same type with a boolean output + fn comparison(t: DataType) -> Self { + Self { + lhs: t.clone(), + rhs: t, + ret: DataType::Boolean, + } + } +} + +/// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs` +fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result { + match op { + Operator::Eq | + Operator::NotEq | + Operator::Lt | + Operator::LtEq | + Operator::Gt | + Operator::GtEq | + Operator::IsDistinctFrom | + Operator::IsNotDistinctFrom => { + comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + DataFusionError::Plan(format!( + "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}" + )) + }) + } + Operator::And | Operator::Or => match (lhs, rhs) { + // logical binary boolean operators can only be evaluated in bools or nulls + (DataType::Boolean, DataType::Boolean) + | (DataType::Null, DataType::Null) + | (DataType::Boolean, DataType::Null) + | (DataType::Null, DataType::Boolean) => Ok(Signature::coerced(DataType::Boolean)), + _ => Err(DataFusionError::Plan(format!( + "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" + ))), + }, + Operator::RegexMatch | + Operator::RegexIMatch | + Operator::RegexNotMatch | + Operator::RegexNotIMatch => { + regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + DataFusionError::Plan(format!( + "Cannot infer common argument type for regex operation {lhs} {op} {rhs}" + )) + }) + } + Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft => { + bitwise_coercion(lhs, rhs).map(Signature::coerced).ok_or_else(|| { + DataFusionError::Plan(format!( + "Cannot infer common type for bitwise operation {lhs} {op} {rhs}" + )) + }) + } + Operator::StringConcat => { + string_concat_coercion(lhs, rhs).map(Signature::coerced).ok_or_else(|| { + DataFusionError::Plan(format!( + "Cannot infer common string type for string concat operation {lhs} {op} {rhs}" + )) + }) + } + Operator::Plus | + Operator::Minus | + Operator::Multiply | + Operator::Divide| + Operator::Modulo => { + // TODO: this logic would be easier to follow if the functions were inlined + if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) { + // Numeric arithmetic, e.g. Int32 + Int32 + Ok(Signature::coerced(numeric)) + } else if let Some(ret) = mathematics_temporal_result_type(lhs, rhs) { + // Temporal arithmetic, e.g. Date32 + Interval + Ok(Signature{ + lhs: lhs.clone(), + rhs: rhs.clone(), + ret, + }) + } else if let Some(coerced) = temporal_coercion(lhs, rhs) { + // Temporal arithmetic by first coercing to a common time representation + // e.g. Date32 - Timestamp + let ret = mathematics_temporal_result_type(&coerced, &coerced).ok_or_else(|| { + DataFusionError::Plan(format!( + "Cannot get result type for temporal operation {coerced} {op} {coerced}" + )) + })?; + Ok(Signature{ + lhs: coerced.clone(), + rhs: coerced, + ret, + }) + } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { + // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0) + let ret = decimal_op_mathematics_type(op, &lhs, &rhs).ok_or_else(|| { + DataFusionError::Plan(format!( + "Cannot get result type for decimal operation {lhs} {op} {rhs}" + )) + })?; + Ok(Signature{ + lhs, + rhs, + ret, + }) + } else { + Err(DataFusionError::Plan(format!( + "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types" + ))) + } + } + } +} + /// Returns the result type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_temporal_result_type( @@ -38,14 +173,6 @@ fn mathematics_temporal_result_type( use arrow::datatypes::IntervalUnit::*; use arrow::datatypes::TimeUnit::*; - if !is_interval(lhs_type) - && !is_interval(rhs_type) - && !is_datetime(lhs_type) - && !is_datetime(rhs_type) - { - return None; - }; - match (lhs_type, rhs_type) { // datetime +/- interval (Interval(_), Timestamp(_, _)) => Some(rhs_type.clone()), @@ -66,185 +193,64 @@ fn mathematics_temporal_result_type( | (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => { Some(Interval(MonthDayNano)) } - (Timestamp(_, _), Timestamp(_, _)) => None, // date - date (Date32, Date32) => Some(Interval(DayTime)), (Date64, Date64) => Some(Interval(MonthDayNano)), - (Date32, Date64) | (Date64, Date32) => Some(Interval(MonthDayNano)), - // date - timestamp, timestamp - date - (Date32, Timestamp(_, _)) - | (Timestamp(_, _), Date32) - | (Date64, Timestamp(_, _)) - | (Timestamp(_, _), Date64) => { - // TODO: make get_result_type must after coerce type. - // if type isn't coerced, we need get common type, and then get result type. - let common_type = temporal_coercion(lhs_type, rhs_type); - common_type.and_then(|t| mathematics_temporal_result_type(&t, &t)) - } _ => None, } } /// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types pub fn get_result_type( - lhs_type: &DataType, + lhs: &DataType, op: &Operator, - rhs_type: &DataType, + rhs: &DataType, ) -> Result { - if op.is_numerical_operators() && any_decimal(lhs_type, rhs_type) { - let (coerced_lhs_type, coerced_rhs_type) = - math_decimal_coercion(lhs_type, rhs_type); - - let lhs_type = coerced_lhs_type.unwrap_or(lhs_type.clone()); - let rhs_type = coerced_rhs_type.unwrap_or(rhs_type.clone()); - - if op.is_numerical_operators() { - if let Some(result_type) = - decimal_op_mathematics_type(op, &lhs_type, &rhs_type) - { - return Ok(result_type); - } - } - } - let result = match op { - Operator::And - | Operator::Or - | Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::Gt - | Operator::GtEq - | Operator::LtEq - | Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch - | Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom => Some(DataType::Boolean), - Operator::Plus | Operator::Minus - if is_datetime(lhs_type) && is_datetime(rhs_type) - || (is_interval(lhs_type) && is_interval(rhs_type)) - || (is_datetime(lhs_type) && is_interval(rhs_type)) - || (is_interval(lhs_type) && is_datetime(rhs_type)) => - { - mathematics_temporal_result_type(lhs_type, rhs_type) - } - // following same with `coerce_types` - Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type), - Operator::Plus - | Operator::Minus - | Operator::Modulo - | Operator::Divide - | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type), - Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type), - }; - - result.ok_or(DataFusionError::Plan(format!( - "Unsupported argument types. Can not evaluate {lhs_type:?} {op} {rhs_type:?}" - ))) + signature(lhs, op, rhs).map(|sig| sig.ret) } -/// Coercion rules for all binary operators. Returns the 'coerce_types' -/// is returns the type the arguments should be coerced to -/// -/// Returns None if no suitable type can be found. -pub fn coerce_types( - lhs_type: &DataType, +/// Returns the coerced input types for a binary expression evaluating the `op` with the left and right hand types +pub fn get_input_types( + lhs: &DataType, op: &Operator, - rhs_type: &DataType, -) -> Result { - // This result MUST be compatible with `binary_coerce` - let result = match op { - Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type), - Operator::And | Operator::Or => match (lhs_type, rhs_type) { - // logical binary boolean operators can only be evaluated in bools or nulls - (DataType::Boolean, DataType::Boolean) - | (DataType::Null, DataType::Null) - | (DataType::Boolean, DataType::Null) - | (DataType::Null, DataType::Boolean) => Some(DataType::Boolean), - _ => None, - }, - // logical comparison operators have their own rules, and always return a boolean - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::Gt - | Operator::GtEq - | Operator::LtEq - | Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom => comparison_coercion(lhs_type, rhs_type), - Operator::Plus | Operator::Minus - if is_interval(lhs_type) && is_interval(rhs_type) => - { - temporal_coercion(lhs_type, rhs_type) - } - Operator::Minus if is_datetime(lhs_type) && is_datetime(rhs_type) => { - temporal_coercion(lhs_type, rhs_type) - } - // for math expressions, the final value of the coercion is also the return type - // because coercion favours higher information types - Operator::Plus - | Operator::Minus - | Operator::Modulo - | Operator::Divide - | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type), - Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch => regex_coercion(lhs_type, rhs_type), - // "||" operator has its own rules, and always return a string type - Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type), - }; - - // re-write the error message of failed coercions to include the operator's information - result.ok_or(DataFusionError::Plan(format!("{lhs_type:?} {op} {rhs_type:?} can't be evaluated because there isn't a common type to coerce the types to"))) + rhs: &DataType, +) -> Result<(DataType, DataType)> { + signature(lhs, op, rhs).map(|sig| (sig.lhs, sig.rhs)) } /// Coercion rules for mathematics operators between decimal and non-decimal types. pub fn math_decimal_coercion( lhs_type: &DataType, rhs_type: &DataType, -) -> (Option, Option) { +) -> Option<(DataType, DataType)> { use arrow::datatypes::DataType::*; - if both_decimal(lhs_type, rhs_type) { - return (None, None); - } - match (lhs_type, rhs_type) { - (Null, dec_type @ Decimal128(_, _)) => (Some(dec_type.clone()), None), - (dec_type @ Decimal128(_, _), Null) => (None, Some(dec_type.clone())), (Dictionary(key_type, value_type), _) => { - let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type); - let lhs_type = value_type - .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))); - (lhs_type, rhs_type) + let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?; + Some((Dictionary(key_type.clone(), Box::new(value_type)), rhs_type)) } (_, Dictionary(key_type, value_type)) => { - let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type); - let rhs_type = value_type - .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))); - (lhs_type, rhs_type) + let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?; + Some((lhs_type, Dictionary(key_type.clone(), Box::new(value_type)))) + } + (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { + Some((dec_type.clone(), dec_type.clone())) + } + (Decimal128(_, _), Decimal128(_, _)) => { + Some((lhs_type.clone(), rhs_type.clone())) + } + (Decimal128(_, _), Float32 | Float64) | (Float32 | Float64, Decimal128(_, _)) => { + Some((Float64, Float64)) } - (Decimal128(_, _), Float32 | Float64) => (Some(Float64), Some(Float64)), - (Float32 | Float64, Decimal128(_, _)) => (Some(Float64), Some(Float64)), (Decimal128(_, _), _) => { - let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type); - (None, converted_decimal_type) + Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?)) } (_, Decimal128(_, _)) => { - let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type); - (converted_decimal_type, None) + Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone())) } - _ => (None, None), + + _ => None, } } @@ -289,9 +295,7 @@ pub(crate) fn bitwise_coercion( } } -/// Returns the output type of applying comparison operations such as -/// `eq`, `not eq`, `lt`, `lteq`, `gt`, and `gteq` to arguments -/// of `lhs_type` and `rhs_type`. +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { if lhs_type == rhs_type { // same type => equality is possible @@ -305,9 +309,8 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -319,8 +322,8 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; - let other_decimal_type = &match other_type { - // This conversion rule is from spark - // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 - Int8 => Decimal128(3, 0), - Int16 => Decimal128(5, 0), - Int32 => Decimal128(10, 0), - Int64 => Decimal128(20, 0), - Float32 => Decimal128(14, 7), - Float64 => Decimal128(30, 15), - _ => { - return None; - } - }; + let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?; match (decimal_type, &other_decimal_type) { (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2), _ => None, @@ -430,6 +421,8 @@ fn get_wider_decimal_type( /// Now, we just support the signed integer type and floating-point type. fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { use arrow::datatypes::DataType::*; + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 match numeric_type { Int8 => Some(Decimal128(3, 0)), Int16 => Some(Decimal128(5, 0)), @@ -499,6 +492,7 @@ pub fn coercion_decimal_mathematics_type( left_decimal_type: &DataType, right_decimal_type: &DataType, ) -> Option { + // TODO: Move this logic into kernel implementations use arrow::datatypes::DataType::*; match (left_decimal_type, right_decimal_type) { // The promotion rule from spark @@ -518,7 +512,7 @@ pub fn coercion_decimal_mathematics_type( } } -/// Returns the output type of applying mathematics operations on decimal types. +/// Returns the output type of applying mathematics operations on two decimal types. /// The rule is from spark. Note that this is different to the coerced type applied /// to two sides of the arithmetic operation. pub fn decimal_op_mathematics_type( @@ -605,29 +599,6 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> } } -/// Determine if at least of one of lhs and rhs is decimal, and the other must be NULL or decimal -fn both_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (_, Null) => is_decimal(lhs_type), - (Null, _) => is_decimal(rhs_type), - (Decimal128(_, _), Decimal128(_, _)) => true, - (Dictionary(_, value_type), _) => is_decimal(value_type) && is_decimal(rhs_type), - (_, Dictionary(_, value_type)) => is_decimal(lhs_type) && is_decimal(value_type), - _ => false, - } -} - -/// Determine if at least of one of lhs and rhs is decimal -pub fn any_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Dictionary(_, value_type), _) => is_decimal(value_type) || is_decimal(rhs_type), - (_, Dictionary(_, value_type)) => is_decimal(lhs_type) || is_decimal(value_type), - (_, _) => is_decimal(lhs_type) || is_decimal(rhs_type), - } -} - /// Coercion rules for Dictionaries: the type that both lhs and rhs /// can be casted to for the purpose of a computation. /// @@ -743,9 +714,6 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Interval(MonthDayNano)), @@ -832,7 +800,6 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { mod tests { use arrow::datatypes::DataType; - use datafusion_common::assert_contains; use datafusion_common::DataFusionError; use datafusion_common::Result; @@ -843,10 +810,13 @@ mod tests { #[test] fn test_coercion_error() -> Result<()> { let result_type = - coerce_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8); + get_input_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8); if let Err(DataFusionError::Plan(e)) = result_type { - assert_eq!(e, "Float32 + Utf8 can't be evaluated because there isn't a common type to coerce the types to"); + assert_eq!( + e, + "Cannot coerce arithmetic expression Float32 + Utf8 to valid types" + ); Ok(()) } else { Err(DataFusionError::Internal( @@ -891,12 +861,14 @@ mod tests { for (i, input_type) in input_types.iter().enumerate() { let expect_type = &result_types[i]; for op in comparison_op_types { - let result_type = coerce_types(&input_decimal, &op, input_type)?; - assert_eq!(expect_type, &result_type); + let (lhs, rhs) = get_input_types(&input_decimal, &op, input_type)?; + assert_eq!(expect_type, &lhs); + assert_eq!(expect_type, &rhs); } } // negative test - let result_type = coerce_types(&input_decimal, &Operator::Eq, &DataType::Boolean); + let result_type = + get_input_types(&input_decimal, &Operator::Eq, &DataType::Boolean); assert!(result_type.is_err()); Ok(()) } @@ -1017,24 +989,27 @@ mod tests { macro_rules! test_coercion_binary_rule { ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{ - let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; - assert_eq!(result, $C_TYPE); + let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?; + assert_eq!(lhs, $C_TYPE); + assert_eq!(rhs, $C_TYPE); }}; } #[test] fn test_date_timestamp_arithmetic_error() -> Result<()> { - let common_type = coerce_types( + let (lhs, rhs) = get_input_types( &DataType::Timestamp(TimeUnit::Nanosecond, None), &Operator::Minus, &DataType::Timestamp(TimeUnit::Millisecond, None), )?; - assert_eq!(common_type.to_string(), "Timestamp(Millisecond, None)"); + assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)"); + assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)"); - let err = coerce_types(&DataType::Date32, &Operator::Plus, &DataType::Date64) - .unwrap_err() - .to_string(); - assert_contains!(&err, "Date32 + Date64 can't be evaluated because there isn't a common type to coerce the types to"); + let (lhs, rhs) = + get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64) + .unwrap(); + assert_eq!(lhs.to_string(), "Date64"); + assert_eq!(rhs.to_string(), "Date64"); Ok(()) } @@ -1234,18 +1209,15 @@ mod tests { lhs_type: DataType, rhs_type: DataType, mathematics_op: Operator, - expected_lhs_type: Option, - expected_rhs_type: Option, + expected_lhs_type: DataType, + expected_rhs_type: DataType, expected_coerced_type: Option, expected_output_type: DataType, ) { // The coerced types for lhs and rhs, if any of them is not decimal - let (l, r) = math_decimal_coercion(&lhs_type, &rhs_type); - assert_eq!(l, expected_lhs_type); - assert_eq!(r, expected_rhs_type); - - let lhs_type = l.unwrap_or(lhs_type); - let rhs_type = r.unwrap_or(rhs_type); + let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!(lhs_type, expected_lhs_type); + assert_eq!(rhs_type, expected_rhs_type); // The coerced type of decimal math expression, applied during expression evaluation let coerced_type = @@ -1264,8 +1236,8 @@ mod tests { DataType::Decimal128(10, 2), DataType::Decimal128(10, 2), Operator::Plus, - None, - None, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 2), Some(DataType::Decimal128(11, 2)), DataType::Decimal128(11, 2), ); @@ -1274,8 +1246,8 @@ mod tests { DataType::Int32, DataType::Decimal128(10, 2), Operator::Plus, - Some(DataType::Decimal128(10, 0)), - None, + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), Some(DataType::Decimal128(13, 2)), DataType::Decimal128(13, 2), ); @@ -1284,8 +1256,8 @@ mod tests { DataType::Int32, DataType::Decimal128(10, 2), Operator::Minus, - Some(DataType::Decimal128(10, 0)), - None, + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), Some(DataType::Decimal128(13, 2)), DataType::Decimal128(13, 2), ); @@ -1294,8 +1266,8 @@ mod tests { DataType::Int32, DataType::Decimal128(10, 2), Operator::Multiply, - Some(DataType::Decimal128(10, 0)), - None, + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), None, DataType::Decimal128(21, 2), ); @@ -1304,8 +1276,8 @@ mod tests { DataType::Int32, DataType::Decimal128(10, 2), Operator::Divide, - Some(DataType::Decimal128(10, 0)), - None, + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), Some(DataType::Decimal128(12, 2)), DataType::Decimal128(23, 11), ); @@ -1314,8 +1286,8 @@ mod tests { DataType::Int32, DataType::Decimal128(10, 2), Operator::Modulo, - Some(DataType::Decimal128(10, 0)), - None, + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), Some(DataType::Decimal128(12, 2)), DataType::Decimal128(10, 2), ); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 412abbfae644..1d9422b8cb19 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -32,7 +32,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ - any_decimal, coerce_types, comparison_coercion, like_coercion, math_decimal_coercion, + comparison_coercion, get_input_types, like_coercion, }; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ @@ -230,72 +230,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let expr = Expr::ILike(Like::new(negated, expr, pattern, escape_char)); Ok(expr) } - Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) => { - // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419 - let left_type = left.get_type(&self.schema)?; - let right_type = right.get_type(&self.schema)?; - match (&left_type, &right_type) { - // Handle some case about Interval. - ( - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), - &DataType::Interval(_), - ) if matches!(op, Operator::Plus | Operator::Minus) => Ok(expr), - ( - &DataType::Interval(_), - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), - ) if matches!(op, Operator::Plus) => Ok(expr), - (DataType::Timestamp(_, _), DataType::Timestamp(_, _)) - if op.is_numerical_operators() => - { - if matches!(op, Operator::Minus) { - Ok(expr) - } else { - Err(DataFusionError::Internal(format!( - "Unsupported operation {op:?} between {left_type:?} and {right_type:?}" - ))) - } - } - // For numerical operations between decimals, we don't coerce the types. - // But if only one of the operands is decimal, we cast the other operand to decimal - // if the other operand is integer. If the other operand is float, we cast the - // decimal operand to float. - (lhs_type, rhs_type) - if op.is_numerical_operators() - && any_decimal(lhs_type, rhs_type) => - { - let (coerced_lhs_type, coerced_rhs_type) = - math_decimal_coercion(lhs_type, rhs_type); - let new_left = if let Some(lhs_type) = coerced_lhs_type { - left.clone().cast_to(&lhs_type, &self.schema)? - } else { - left.as_ref().clone() - }; - let new_right = if let Some(rhs_type) = coerced_rhs_type { - right.clone().cast_to(&rhs_type, &self.schema)? - } else { - right.as_ref().clone() - }; - let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(new_left), - op, - Box::new(new_right), - )); - Ok(expr) - } - _ => { - let common_type = coerce_types(&left_type, &op, &right_type)?; - let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.clone().cast_to(&common_type, &self.schema)?), - op, - Box::new(right.clone().cast_to(&common_type, &self.schema)?), - )); - Ok(expr) - } - } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + let lhs = left.get_type(&self.schema)?; + let rhs = right.get_type(&self.schema)?; + let (lhs, rhs) = get_input_types(&lhs, &op, &rhs)?; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left.cast_to(&lhs, &self.schema)?), + op, + Box::new(right.cast_to(&rhs, &self.schema)?), + ))) } Expr::Between(Between { expr, @@ -566,7 +509,7 @@ fn coerce_window_frame( // The above op will be rewrite to the binary op when creating the physical op. fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result { let left_type = expr.get_type(schema)?; - coerce_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; + get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; expr.clone().cast_to(&DataType::Boolean, schema) } @@ -1108,9 +1051,9 @@ mod test { let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, ""); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains("Int64 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to")); + let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, ""); + let err = ret.unwrap_err().to_string(); + assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); // is not true let expr = col("a").is_not_true(); @@ -1210,9 +1153,9 @@ mod test { let empty = empty_with_type(DataType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains("Utf8 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to")); + let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); + let err = ret.unwrap_err().to_string(); + assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); // is not unknown let expr = col("a").is_not_unknown(); diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 8e9e361596b0..67238baa2dbb 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -1334,7 +1334,9 @@ mod tests { ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef, }; use datafusion_common::{ColumnStatistics, Result, Statistics}; - use datafusion_expr::type_coercion::binary::{coerce_types, math_decimal_coercion}; + use datafusion_expr::type_coercion::binary::{ + get_input_types, math_decimal_coercion, + }; // Create a binary expression without coercion. Used here when we do not want to coerce the expressions // to valid types. Usage can result in an execution (after plan) error. @@ -1438,10 +1440,10 @@ mod tests { ]); let a = $A_ARRAY::from($A_VEC); let b = $B_ARRAY::from($B_VEC); - let common_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; + let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?; - let left = try_cast(col("a", &schema)?, &schema, common_type.clone())?; - let right = try_cast(col("b", &schema)?, &schema, common_type)?; + let left = try_cast(col("a", &schema)?, &schema, lhs)?; + let right = try_cast(col("b", &schema)?, &schema, rhs)?; // verify that we can construct the expression let expression = binary(left, $OP, right, &schema)?; @@ -2955,10 +2957,10 @@ mod tests { ) -> Result<()> { let left_type = left.data_type(); let right_type = right.data_type(); - let common_type = coerce_types(left_type, &op, right_type)?; + let (lhs, rhs) = get_input_types(left_type, &op, right_type)?; - let left_expr = try_cast(col("a", schema)?, schema, common_type.clone())?; - let right_expr = try_cast(col("b", schema)?, schema, common_type)?; + let left_expr = try_cast(col("a", schema)?, schema, lhs)?; + let right_expr = try_cast(col("b", schema)?, schema, rhs)?; let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; @@ -2977,17 +2979,10 @@ mod tests { expected: &BooleanArray, ) -> Result<()> { let scalar = lit(scalar.clone()); - let op_type = coerce_types(&scalar.data_type(schema)?, &op, arr.data_type())?; - let left_expr = if op_type.eq(&scalar.data_type(schema)?) { - scalar - } else { - try_cast(scalar, schema, op_type.clone())? - }; - let right_expr = if op_type.eq(arr.data_type()) { - col("a", schema)? - } else { - try_cast(col("a", schema)?, schema, op_type)? - }; + let (lhs, rhs) = + get_input_types(&scalar.data_type(schema)?, &op, arr.data_type())?; + let left_expr = try_cast(scalar, schema, lhs)?; + let right_expr = try_cast(col("a", schema)?, schema, rhs)?; let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; @@ -3006,17 +3001,10 @@ mod tests { expected: &BooleanArray, ) -> Result<()> { let scalar = lit(scalar.clone()); - let op_type = coerce_types(arr.data_type(), &op, &scalar.data_type(schema)?)?; - let right_expr = if op_type.eq(&scalar.data_type(schema)?) { - scalar - } else { - try_cast(scalar, schema, op_type.clone())? - }; - let left_expr = if op_type.eq(arr.data_type()) { - col("a", schema)? - } else { - try_cast(col("a", schema)?, schema, op_type)? - }; + let (lhs, rhs) = + get_input_types(arr.data_type(), &op, &scalar.data_type(schema)?)?; + let left_expr = try_cast(col("a", schema)?, schema, lhs)?; + let right_expr = try_cast(scalar, schema, rhs)?; let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; @@ -4068,26 +4056,11 @@ mod tests { op: Operator, expected: ArrayRef, ) -> Result<()> { - let (lhs_op_type, rhs_op_type) = - math_decimal_coercion(left.data_type(), right.data_type()); - - let (left_expr, lhs_type) = if let Some(lhs_op_type) = lhs_op_type { - ( - try_cast(col("a", schema)?, schema, lhs_op_type.clone())?, - lhs_op_type, - ) - } else { - (col("a", schema)?, left.data_type().clone()) - }; + let (lhs_type, rhs_type) = + math_decimal_coercion(left.data_type(), right.data_type()).unwrap(); - let (right_expr, rhs_type) = if let Some(rhs_op_type) = rhs_op_type { - ( - try_cast(col("b", schema)?, schema, rhs_op_type.clone())?, - rhs_op_type, - ) - } else { - (col("b", schema)?, right.data_type().clone()) - }; + let left_expr = try_cast(col("a", schema)?, schema, lhs_type.clone())?; + let right_expr = try_cast(col("b", schema)?, schema, rhs_type.clone())?; let coerced_schema = Schema::new(vec![ Field::new(