diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 10d8306d767cb..615bb3ac568c2 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -17,6 +17,7 @@ //! Coercion rules for matching argument types for binary operators +use std::collections::HashSet; use std::sync::Arc; use crate::Operator; @@ -289,13 +290,207 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option for TypeCategory { + fn from(data_type: &DataType) -> Self { + match data_type { + // Dict is a special type in arrow, we check the value type + DataType::Dictionary(_, v) => { + let v = v.as_ref(); + TypeCategory::from(v) + } + _ => { + if data_type.is_numeric() { + return TypeCategory::Numeric; + } + + if matches!(data_type, DataType::Boolean) { + return TypeCategory::Boolean; + } + + if matches!( + data_type, + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + ) { + return TypeCategory::Array; + } + + // String literal is possible to cast to many other types like numeric or datetime, + // therefore, it is categorized as a unknown type + if matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) { + return TypeCategory::Unknown; + } + + if matches!( + data_type, + DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Interval(_) + | DataType::Duration(_) + ) { + return TypeCategory::DateTime; + } + + if matches!( + data_type, + DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) + ) { + return TypeCategory::Composite; + } + + TypeCategory::NotSupported + } + } + } +} + +/// Coerce dissimilar data types to a single data type. +/// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are +/// examples that has the similar resolution rules. +/// See for more information. +/// The rules in the document provide a clue, but adhering strictly to them doesn't precisely +/// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules +/// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted +/// decimal percision and scale when coercing decimal types. +pub fn type_union_resolution(data_types: &[DataType]) -> Option { + if data_types.is_empty() { + return None; + } + + // if all the data_types is the same return first one + if data_types.iter().all(|t| t == &data_types[0]) { + return Some(data_types[0].clone()); + } + + // if all the data_types are null, return string + if data_types.iter().all(|t| t == &DataType::Null) { + return Some(DataType::Utf8); + } + + // Ignore Nulls, if any data_type category is not the same, return None + let data_types_category: Vec = data_types + .iter() + .filter(|&t| t != &DataType::Null) + .map(|t| t.into()) + .collect(); + + if data_types_category + .iter() + .any(|t| t == &TypeCategory::NotSupported) + { + return None; + } + + // check if there is only one category excluding Unknown + let categories: HashSet = HashSet::from_iter( + data_types_category + .iter() + .filter(|&c| c != &TypeCategory::Unknown) + .cloned(), + ); + if categories.len() > 1 { + return None; + } + + // Ignore Nulls + let mut candidate_type: Option = None; + for data_type in data_types.iter() { + if data_type == &DataType::Null { + continue; + } + if let Some(ref candidate_t) = candidate_type { + // Find candidate type that all the data types can be coerced to + // Follows the behavior of Postgres and DuckDB + // Coerced type may be different from the candidate and current data type + // For example, + // i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2) + // numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2) + if let Some(t) = type_union_resolution_coercion(data_type, candidate_t) { + candidate_type = Some(t); + } else { + return None; + } + } else { + candidate_type = Some(data_type.clone()); + } + } + + candidate_type +} + +/// Coerce `lhs_type` and `rhs_type` to a common type for [type_union_resolution] +/// See [type_union_resolution] for more information. +fn type_union_resolution_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + match (lhs_type, rhs_type) { + ( + DataType::Dictionary(lhs_index_type, lhs_value_type), + DataType::Dictionary(rhs_index_type, rhs_value_type), + ) => { + let new_index_type = + type_union_resolution_coercion(lhs_index_type, rhs_index_type); + let new_value_type = + type_union_resolution_coercion(lhs_value_type, rhs_value_type); + if let (Some(new_index_type), Some(new_value_type)) = + (new_index_type, new_value_type) + { + Some(DataType::Dictionary( + Box::new(new_index_type), + Box::new(new_value_type), + )) + } else { + None + } + } + (DataType::Dictionary(index_type, value_type), other_type) + | (other_type, DataType::Dictionary(index_type, value_type)) => { + let new_value_type = type_union_resolution_coercion(value_type, other_type); + new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) + } + _ => { + // numeric coercion is the same as comparison coercion, both find the narrowest type + // that can accommodate both types + binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) + } + } +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation +/// Unlike `coerced_from`, usually the coerced type is for comparison only. +/// For example, compare with Dictionary and Dictionary, only value type is what we care about pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { if lhs_type == rhs_type { // same type => equality is possible return Some(lhs_type.clone()); } - comparison_binary_numeric_coercion(lhs_type, rhs_type) + binary_numeric_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) @@ -312,7 +507,7 @@ pub fn values_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option equality is possible return Some(lhs_type.clone()); } - comparison_binary_numeric_coercion(lhs_type, rhs_type) + binary_numeric_coercion(lhs_type, rhs_type) .or_else(|| temporal_coercion(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| binary_coercion(lhs_type, rhs_type)) @@ -372,9 +567,8 @@ fn string_temporal_coercion( match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type)) } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation -/// where one both are numeric -pub(crate) fn comparison_binary_numeric_coercion( +/// Coerce `lhs_type` and `rhs_type` to a common type where both are numeric +pub(crate) fn binary_numeric_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { @@ -388,20 +582,13 @@ pub(crate) fn comparison_binary_numeric_coercion( return Some(lhs_type.clone()); } + if let Some(t) = decimal_coercion(lhs_type, rhs_type) { + return Some(t); + } + // these are ordered from most informative to least informative so // that the coercion does not lose information via truncation match (lhs_type, rhs_type) { - // Prefer decimal data type over floating point for comparison operation - (Decimal128(_, _), Decimal128(_, _)) => { - get_wider_decimal_type(lhs_type, rhs_type) - } - (Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), - (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), - (Decimal256(_, _), Decimal256(_, _)) => { - get_wider_decimal_type(lhs_type, rhs_type) - } - (Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), - (_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), (Float64, _) | (_, Float64) => Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), // The following match arms encode the following logic: Given the two @@ -409,6 +596,11 @@ pub(crate) fn comparison_binary_numeric_coercion( // accommodates all values of both types. Note that some information // loss is inevitable when we have a signed type and a `UInt64`, in // which case we use `Int64`;i.e. the widest signed integral type. + + // TODO: For i64 and u64, we can use decimal or float64 + // Postgres has no unsigned type :( + // DuckDB v.0.10.0 has double (double precision floating-point number (8 bytes)) + // for largest signed (signed sixteen-byte integer) and unsigned integer (unsigned sixteen-byte integer) (Int64, _) | (_, Int64) | (UInt64, Int8) @@ -439,9 +631,28 @@ pub(crate) fn comparison_binary_numeric_coercion( } } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of -/// a comparison operation where one is a decimal -fn get_comparison_common_decimal_type( +/// Decimal coercion rules. +pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + + match (lhs_type, rhs_type) { + // Prefer decimal data type over floating point for comparison operation + (Decimal128(_, _), Decimal128(_, _)) => { + get_wider_decimal_type(lhs_type, rhs_type) + } + (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), + (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), + (Decimal256(_, _), Decimal256(_, _)) => { + get_wider_decimal_type(lhs_type, rhs_type) + } + (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), + (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), + (_, _) => None, + } +} + +/// Coerce `lhs_type` and `rhs_type` to a common type. +fn get_common_decimal_type( decimal_type: &DataType, other_type: &DataType, ) -> Option { @@ -725,6 +936,18 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } +fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Utf8 | LargeUtf8, other_type) | (other_type, Utf8 | LargeUtf8) + if other_type.is_numeric() => + { + Some(other_type.clone()) + } + _ => None, + } +} + /// Coercion rules for list types. fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b41ec109103d1..623fa25742600 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -30,7 +30,7 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, Result, }; -use super::binary::{comparison_binary_numeric_coercion, comparison_coercion}; +use super::binary::comparison_coercion; /// Performs type coercion for scalar function arguments. /// @@ -66,20 +66,7 @@ pub fn data_types_with_scalar_udf( return Ok(current_types.to_vec()); } - // Try and coerce the argument types to match the signature, returning the - // coerced types from the first matching signature. - for valid_types in valid_types { - if let Some(types) = maybe_data_types(&valid_types, current_types) { - return Ok(types); - } - } - - // none possible -> Error - plan_err!( - "[data_types_with_scalar_udf] Coercion from {:?} to the signature {:?} failed.", - current_types, - &signature.type_signature - ) + try_coerce_types(valid_types, current_types, &signature.type_signature) } pub fn data_types_with_aggregate_udf( @@ -112,20 +99,7 @@ pub fn data_types_with_aggregate_udf( return Ok(current_types.to_vec()); } - // Try and coerce the argument types to match the signature, returning the - // coerced types from the first matching signature. - for valid_types in valid_types { - if let Some(types) = maybe_data_types(&valid_types, current_types) { - return Ok(types); - } - } - - // none possible -> Error - plan_err!( - "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", - current_types, - &signature.type_signature - ) + try_coerce_types(valid_types, current_types, &signature.type_signature) } /// Performs type coercion for function arguments. @@ -152,7 +126,6 @@ pub fn data_types( } let valid_types = get_valid_types(&signature.type_signature, current_types)?; - if valid_types .iter() .any(|data_type| data_type == current_types) @@ -160,19 +133,39 @@ pub fn data_types( return Ok(current_types.to_vec()); } - // Try and coerce the argument types to match the signature, returning the - // coerced types from the first matching signature. - for valid_types in valid_types { - if let Some(types) = maybe_data_types(&valid_types, current_types) { - return Ok(types); + try_coerce_types(valid_types, current_types, &signature.type_signature) +} + +fn try_coerce_types( + valid_types: Vec>, + current_types: &[DataType], + type_signature: &TypeSignature, +) -> Result> { + let mut valid_types = valid_types; + + // Well-supported signature that returns exact valid types. + if !valid_types.is_empty() && matches!(type_signature, TypeSignature::UserDefined) { + // exact valid types + assert_eq!(valid_types.len(), 1); + let valid_types = valid_types.swap_remove(0); + if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) { + return Ok(t); + } + } else { + // Try and coerce the argument types to match the signature, returning the + // coerced types from the first matching signature. + for valid_types in valid_types { + if let Some(types) = maybe_data_types(&valid_types, current_types) { + return Ok(types); + } } } // none possible -> Error plan_err!( - "[data_types] Coercion from {:?} to the signature {:?} failed.", + "Coercion from {:?} to the signature {:?} failed.", current_types, - &signature.type_signature + type_signature ) } @@ -438,6 +431,8 @@ fn maybe_data_types( new_type.push(current_type.clone()) } else { // attempt to coerce. + // TODO: Replace with `can_cast_types` after failing cases are resolved + // (they need new signature that returns exactly valid types instead of list of possible valid types). if let Some(coerced_type) = coerced_from(valid_type, current_type) { new_type.push(coerced_type) } else { @@ -449,6 +444,33 @@ fn maybe_data_types( Some(new_type) } +/// Check if the current argument types can be coerced to match the given `valid_types` +/// unlike `maybe_data_types`, this function does not coerce the types. +/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported. +fn maybe_data_types_without_coercion( + valid_types: &[DataType], + current_types: &[DataType], +) -> Option> { + if valid_types.len() != current_types.len() { + return None; + } + + let mut new_type = Vec::with_capacity(valid_types.len()); + for (i, valid_type) in valid_types.iter().enumerate() { + let current_type = ¤t_types[i]; + + if current_type == valid_type { + new_type.push(current_type.clone()) + } else if can_cast_types(current_type, valid_type) { + // validate the valid type is castable from the current type + new_type.push(valid_type.clone()) + } else { + return None; + } + } + Some(new_type) +} + /// Return true if a value of type `type_from` can be coerced /// (losslessly converted) into a value of `type_to` /// @@ -463,11 +485,18 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { false } +/// Find the coerced type for the given `type_into` and `type_from`. +/// Returns `None` if coercion is not possible. +/// +/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32. +/// +/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. fn coerced_from<'a>( type_into: &'a DataType, type_from: &'a DataType, ) -> Option { use self::DataType::*; + // match Dictionary first match (type_into, type_from) { // coerced dictionary first @@ -585,7 +614,6 @@ fn coerced_from<'a>( } _ => None, }, - (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { @@ -606,19 +634,7 @@ fn coerced_from<'a>( { Some(type_into.clone()) } - // More coerce rules. - // Note that not all rules in `comparison_coercion` can be reused here. - // For example, all numeric types can be coerced into Utf8 for comparison, - // but not for function arguments. - _ => comparison_binary_numeric_coercion(type_into, type_from).and_then( - |coerced_type| { - if *type_into == coerced_type { - Some(coerced_type) - } else { - None - } - }, - ), + _ => None, } } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 63778eb7738ac..15a3ddd9d6e9d 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -22,8 +22,8 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; -use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::type_coercion::binary::comparison_coercion; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::type_coercion::binary::type_union_resolution; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -124,21 +124,11 @@ impl ScalarUDFImpl for CoalesceFunc { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let new_type = arg_types.iter().skip(1).try_fold( - arg_types.first().unwrap().clone(), - |acc, x| { - // The coerced types found by `comparison_coercion` are not guaranteed to be - // coercible for the arguments. `comparison_coercion` returns more loose - // types that can be coerced to both `acc` and `x` for comparison purpose. - // See `maybe_data_types` for the actual coercion. - let coerced_type = comparison_coercion(&acc, x); - if let Some(coerced_type) = coerced_type { - Ok(coerced_type) - } else { - internal_err!("Coercion from {acc:?} to {x:?} failed.") - } - }, - )?; + if arg_types.is_empty() { + return exec_err!("coalesce must have at least one argument"); + } + let new_type = type_union_resolution(arg_types) + .unwrap_or(arg_types.first().unwrap().clone()); Ok(vec![new_type; arg_types.len()]) } } diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt index a0317ac4a5f4e..17b0e774d9cb7 100644 --- a/datafusion/sqllogictest/test_files/coalesce.slt +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -208,7 +208,7 @@ select ---- [3, 4] List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) -# TODO: after switch signature of array to the same with coalesce, this query should be fixed +# coalesce with array query ?T select coalesce(array[1, 2], array[arrow_cast(3, 'Int32'), arrow_cast(4, 'Int32')]), @@ -216,10 +216,10 @@ select ---- [1, 2] List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +# test dict(int32, utf8) statement ok create table test1 as values (arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (null); -# Dictionary and String are not coercible query ? select coalesce(column1, 'none_set') from test1; ---- @@ -242,7 +242,59 @@ none_set statement ok drop table test1 -# Numeric and Dictionary are not coercible +# test dict coercion with value +statement ok +create table t(c varchar) as values ('a'), (null); + +query TT +select + coalesce(c, arrow_cast('b', 'Dictionary(Int32, Utf8)')), + arrow_typeof(coalesce(c, arrow_cast('b', 'Dictionary(Int32, Utf8)'))) +from t; +---- +a Dictionary(Int32, Utf8) +b Dictionary(Int32, Utf8) + +statement ok +drop table t; + +# test dict coercion with dict +statement ok +create table t as values + (arrow_cast('foo', 'Dictionary(Int32, Utf8)')), + (null); + +query ?T +select + coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, LargeUtf8)')), + arrow_typeof(coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, LargeUtf8)'))) +from t; +---- +foo Dictionary(Int64, LargeUtf8) +bar Dictionary(Int64, LargeUtf8) + +query ?T +select + coalesce(column1, arrow_cast('bar', 'Dictionary(Int32, LargeUtf8)')), + arrow_typeof(coalesce(column1, arrow_cast('bar', 'Dictionary(Int32, LargeUtf8)'))) +from t; +---- +foo Dictionary(Int32, LargeUtf8) +bar Dictionary(Int32, LargeUtf8) + +query ?T +select + coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, Utf8)')), + arrow_typeof(coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, Utf8)'))) +from t; +---- +foo Dictionary(Int64, Utf8) +bar Dictionary(Int64, Utf8) + +statement ok +drop table t; + +# test dict(int32, int8) query I select coalesce(34, arrow_cast(123, 'Dictionary(Int32, Int8)')); ---- @@ -258,6 +310,12 @@ select coalesce(null, 34, arrow_cast(123, 'Dictionary(Int32, Int8)')); ---- 34 +# numeric string coercion +query RT +select coalesce(2.0, 1, '3'), arrow_typeof(coalesce(2.0, 1, '3')); +---- +2 Float64 + # explicitly cast to Int8, and it will implicitly cast to Int64 query IT select