Move Coercion for MakeArray to coerce_arguments_for_signature and introduce another one for ArrayAppend #8317

Merged (7 commits) on Dec 18, 2023
38 changes: 38 additions & 0 deletions datafusion/common/src/utils.rs
@@ -342,6 +342,8 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
count
}

/// Array Utils

/// Wrap an array into a single element `ListArray`.
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
@@ -429,6 +431,42 @@ pub fn base_type(data_type: &DataType) -> DataType {
}
}

/// A helper function to coerce base type in List.
///
/// Example
/// ```
/// use arrow::datatypes::{DataType, Field};
/// use datafusion_common::utils::coerced_type_with_base_type_only;
/// use std::sync::Arc;
///
/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
/// let base_type = DataType::Float64;
/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type);
/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true))));
/// ```
pub fn coerced_type_with_base_type_only(
data_type: &DataType,
base_type: &DataType,
) -> DataType {
match data_type {
DataType::List(field) => {
let data_type = match field.data_type() {
DataType::List(_) => {
coerced_type_with_base_type_only(field.data_type(), base_type)
}
_ => base_type.to_owned(),
};

DataType::List(Arc::new(Field::new(
field.name(),
data_type,
field.is_nullable(),
)))
}

_ => base_type.clone(),
}
}
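
The recursion in `coerced_type_with_base_type_only` can be sketched without Arrow. The following is a minimal, std-only illustration, not the PR's code: `Ty` is a hypothetical stand-in for `DataType`, and `with_base_type` is a hypothetical name for the helper. It shows that only the innermost (base) type is swapped while the list nesting is preserved.

```rust
/// Hypothetical stand-in for Arrow's `DataType`; only what the sketch needs.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    Null,
    Int32,
    Float64,
    List(Box<Ty>),
}

/// Replace the innermost (base) type of a possibly nested list,
/// keeping the number of list levels unchanged.
fn with_base_type(ty: &Ty, base: &Ty) -> Ty {
    match ty {
        // Descend one list level and rebuild the wrapper around the result.
        Ty::List(inner) => Ty::List(Box::new(with_base_type(inner, base))),
        // Any non-list type is simply replaced by the new base type.
        _ => base.clone(),
    }
}

fn main() {
    // List(List(Int32)) with base Float64 becomes List(List(Float64)).
    let nested = Ty::List(Box::new(Ty::List(Box::new(Ty::Int32))));
    let coerced = with_base_type(&nested, &Ty::Float64);
    assert_eq!(
        coerced,
        Ty::List(Box::new(Ty::List(Box::new(Ty::Float64))))
    );

    // A non-list input is replaced outright.
    assert_eq!(with_base_type(&Ty::Null, &Ty::Int32), Ty::Int32);
    println!("{coerced:?}");
}
```

Unlike the real function, this sketch does not carry field names or nullability, which the PR copies over from the original `Field`.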

/// Compute the number of dimensions in a list data type.
pub fn list_ndims(data_type: &DataType) -> u64 {
if let DataType::List(field) = data_type {
13 changes: 8 additions & 5 deletions datafusion/expr/src/built_in_function.rs
@@ -915,10 +915,17 @@ impl BuiltinScalarFunction {

// for now, the list is small, as we do not have many built-in functions.
match self {
BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArraySort => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayAppend => Signature {
type_signature: ArrayAndElement,
volatility: self.volatility(),
},
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments of arbitrary type
Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility())
}
BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayConcat => {
@@ -958,10 +965,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments of arbitrary type
Signature::one_of(vec![VariadicAny, Any(0)], self.volatility())
}
BuiltinScalarFunction::Range => Signature::one_of(
vec![
Exact(vec![Int64]),
20 changes: 17 additions & 3 deletions datafusion/expr/src/signature.rs
@@ -91,11 +91,14 @@ pub enum TypeSignature {
/// DataFusion attempts to coerce all argument types to match the first argument's type
///
/// # Examples
/// A function such as `array` is `VariadicEqual`
/// Given types in signature should be coercible to the same final type.
/// A function such as `make_array` is `VariadicEqual`.
///
/// `make_array(i32, i64) -> make_array(i64, i64)`
VariadicEqual,
/// One or more arguments with arbitrary types
VariadicAny,
/// fixed number of arguments of an arbitrary but equal type out of a list of valid types.
/// Fixed number of arguments of an arbitrary but equal type out of a list of valid types.
///
/// # Examples
/// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
@@ -113,6 +116,12 @@ pub enum TypeSignature {
/// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature`
/// is `OneOf(vec![Any(0), VariadicAny])`.
OneOf(Vec<TypeSignature>),
/// Specialized Signature for ArrayAppend and similar functions
Review comment (Contributor):

What do you think about using a more generic name? Perhaps something like:

    /// The first argument is an array type ([`DataType::List`], or [`DataType::LargeList`])
    /// and the subsequent arguments are coerced to the List's element type
    ///
    /// For example a call to `func(a: List(int64), b: int32, c: utf8)` would attempt to coerce
    /// all the arguments to `int64`: 
    /// ```
    /// func(a: List(int64), cast(b as int64): int64, cast(c as int64): int64)
    /// ```
    ArrayAndElements

There may be more general ways of expressing the array function types too 🤔

/// The first argument should be a List/LargeList, and the second argument may be a list or a non-list.
/// The second argument's list dimension should be one less than the first argument's list dimension.
/// The list dimension of a List/LargeList is its nesting depth (the number of nested `List` levels).
/// The list dimension of a non-list is 0.
ArrayAndElement,
}
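
The dimension rule stated in the `ArrayAndElement` doc comment can be sketched as follows. This is a std-only illustration under assumptions, not the PR's implementation: `Ty` is a hypothetical stand-in for `DataType`, `ndims` mirrors what `list_ndims` is described as computing, and `dims_match_array_and_element` is a hypothetical helper name.

```rust
/// Hypothetical stand-in for Arrow's `DataType`.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    Int64,
    List(Box<Ty>),
}

/// Number of nested list levels; a non-list has dimension 0.
fn ndims(ty: &Ty) -> u64 {
    match ty {
        Ty::List(inner) => 1 + ndims(inner),
        _ => 0,
    }
}

/// The first argument must be a list, and the second argument must have
/// exactly one list dimension fewer than the first.
fn dims_match_array_and_element(array: &Ty, elem: &Ty) -> bool {
    let array_dims = ndims(array);
    array_dims >= 1 && ndims(elem) + 1 == array_dims
}

fn main() {
    let int = Ty::Int64;
    let list = Ty::List(Box::new(Ty::Int64));
    let list2 = Ty::List(Box::new(list.clone()));

    assert!(dims_match_array_and_element(&list, &int)); // like array_append([1], 2)
    assert!(dims_match_array_and_element(&list2, &list)); // like array_append([[1]], [2])
    assert!(!dims_match_array_and_element(&int, &int)); // like array_append(1, 2)
    assert!(!dims_match_array_and_element(&list, &list)); // like array_append([1], [2])
    println!("dimension checks passed");
}
```

The rejected cases correspond to the `query error` cases added to `array.slt` in this PR.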

impl TypeSignature {
@@ -136,11 +145,16 @@ impl TypeSignature {
.collect::<Vec<&str>>()
.join(", ")]
}
TypeSignature::VariadicEqual => vec!["T, .., T".to_string()],
TypeSignature::VariadicEqual => {
vec!["CoercibleT, .., CoercibleT".to_string()]
}
TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()],
TypeSignature::OneOf(sigs) => {
sigs.iter().flat_map(|s| s.to_string_repr()).collect()
}
TypeSignature::ArrayAndElement => {
vec!["ArrayAndElement(List<T>, T)".to_string()]
}
}
}

74 changes: 68 additions & 6 deletions datafusion/expr/src/type_coercion/functions.rs
@@ -21,7 +21,10 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{plan_err, DataFusionError, Result};
use datafusion_common::utils::list_ndims;
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};

use super::binary::comparison_coercion;

/// Performs type coercion for function arguments.
///
@@ -86,16 +89,66 @@ fn get_valid_types(
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual => {
// one entry with the same len as current_types, whose type is `current_types[0]`.
vec![current_types
.iter()
.map(|_| current_types[0].clone())
.collect()]
let new_type = current_types.iter().skip(1).try_fold(
current_types.first().unwrap().clone(),
|acc, x| {
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
Review comment (Contributor Author):

Previously we called `unwrap_or`, so `select [1, true, null]` unexpectedly succeeded, since `true` is castable to `1` in arrow-rs but not in DataFusion, while `select [true, 1, null]` failed. It is better to just return an error if the types are not coercible.
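
The difference between the old `fold` with `unwrap_or` and the new `try_fold` can be sketched with a toy coercion table. Everything here is an assumption for illustration: `Ty`, `coerce`, `lenient`, and `strict` are hypothetical names, and `coerce` is a deliberately tiny stand-in for `comparison_coercion`, not its real rule set.

```rust
/// Hypothetical stand-in for Arrow's `DataType`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Ty {
    Null,
    Bool,
    Int64,
}

/// Toy coercion table: equal types coerce to themselves, Null coerces to
/// the other type, and Bool/Int64 do not coerce to each other.
fn coerce(a: Ty, b: Ty) -> Option<Ty> {
    match (a, b) {
        (x, y) if x == y => Some(x),
        (Ty::Null, y) => Some(y),
        (x, Ty::Null) => Some(x),
        _ => None,
    }
}

/// Old behaviour: a failed coercion is ignored and the accumulator is kept,
/// so the result silently depends on argument order.
/// Assumes `types` is non-empty, as the original code does.
fn lenient(types: &[Ty]) -> Ty {
    types
        .iter()
        .skip(1)
        .fold(types[0], |acc, &x| coerce(acc, x).unwrap_or(acc))
}

/// New behaviour: the first failed coercion aborts with an error.
/// Assumes `types` is non-empty, as the original code does.
fn strict(types: &[Ty]) -> Result<Ty, String> {
    types.iter().skip(1).try_fold(types[0], |acc, &x| {
        coerce(acc, x).ok_or_else(|| format!("Coercion from {acc:?} to {x:?} failed."))
    })
}

fn main() {
    // The lenient fold picks whichever type came first.
    assert_eq!(lenient(&[Ty::Int64, Ty::Bool, Ty::Null]), Ty::Int64);
    assert_eq!(lenient(&[Ty::Bool, Ty::Int64, Ty::Null]), Ty::Bool);

    // The strict fold rejects both orders, and still accepts coercible input.
    assert!(strict(&[Ty::Int64, Ty::Bool, Ty::Null]).is_err());
    assert!(strict(&[Ty::Bool, Ty::Int64, Ty::Null]).is_err());
    assert_eq!(strict(&[Ty::Int64, Ty::Null, Ty::Int64]), Ok(Ty::Int64));
    println!("fold comparison passed");
}
```

With the strict version, both `[1, true, null]` and `[true, 1, null]` are rejected for the same reason, which removes the order dependence described in the comment.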

}
},
);

match new_type {
Ok(new_type) => vec![vec![new_type; current_types.len()]],
Err(e) => return Err(e),
}
}
TypeSignature::VariadicAny => {
vec![current_types.to_vec()]
}

TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArrayAndElement => {
if current_types.len() != 2 {
return Ok(vec![vec![]]);
}

let array_type = &current_types[0];
let elem_type = &current_types[1];

// We follow Postgres on `array_append(Null, T)`, which is not valid.
if array_type.eq(&DataType::Null) {
return Ok(vec![vec![]]);
}

// We need to find the coerced base type, mainly for cases like:
// `array_append(List(null), i64)` -> `List(i64)`
let array_base_type = datafusion_common::utils::base_type(array_type);
let elem_base_type = datafusion_common::utils::base_type(elem_type);
let new_base_type = comparison_coercion(&array_base_type, &elem_base_type);

if new_base_type.is_none() {
return internal_err!(
"Coercion from {array_base_type:?} to {elem_base_type:?} not supported."
);
}
let new_base_type = new_base_type.unwrap();

let array_type = datafusion_common::utils::coerced_type_with_base_type_only(
array_type,
&new_base_type,
);

if let DataType::List(ref field) = array_type {
let elem_type = field.data_type();
return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]);
} else {
return Ok(vec![vec![]]);
}
}
TypeSignature::Any(number) => {
if current_types.len() != *number {
return plan_err!(
@@ -241,6 +294,15 @@ fn coerced_from<'a>(
Utf8 | LargeUtf8 => Some(type_into.clone()),
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()),

// Only accept list with the same number of dimensions unless the type is Null.
// List with different dimensions should be handled in TypeSignature or other places before this.
List(_)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
}

Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
match type_from {
Timestamp(_, Some(from_tz)) => {
34 changes: 0 additions & 34 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -590,26 +590,6 @@ fn coerce_arguments_for_fun(
.collect::<Result<Vec<_>>>()?;
}

if *fun == BuiltinScalarFunction::MakeArray {
Review comment (Contributor): ❤️

// Find the final data type for the function arguments
let current_types = expressions
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

let new_type = current_types
.iter()
.skip(1)
.fold(current_types.first().unwrap().clone(), |acc, x| {
comparison_coercion(&acc, x).unwrap_or(acc)
});

return expressions
.iter()
.zip(current_types)
.map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema))
.collect();
}
Ok(expressions)
}

@@ -618,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr>
expr.clone().cast_to(to_type, schema)
}

/// Cast array `expr` to the specified type, if possible
fn cast_array_expr(
expr: &Expr,
from_type: &DataType,
to_type: &DataType,
schema: &DFSchema,
) -> Result<Expr> {
if from_type.equals_datatype(&DataType::Null) {
Ok(expr.clone())
} else {
cast_expr(expr, to_type, schema)
}
}

/// Returns the coerced exprs for each `input_exprs`.
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the
/// data type of `input_exprs` need to be coerced.
25 changes: 8 additions & 17 deletions datafusion/physical-expr/src/array_expressions.rs
@@ -361,7 +361,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
match data_type {
// Either an empty array or all nulls:
DataType::Null => {
let array = new_null_array(&DataType::Null, arrays.len());
let array =
new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum());
Ok(Arc::new(array_into_list_array(array)))
}
DataType::LargeList(..) => array_array::<i64>(arrays, data_type),
@@ -827,10 +828,14 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
let element_array = &args[1];

check_datatypes("array_append", &[list_array.values(), element_array])?;
let res = match list_array.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::Null => return make_array(&[element_array.to_owned()]),
DataType::Null => {
return make_array(&[
list_array.values().to_owned(),
element_array.to_owned(),
]);
}
data_type => {
return general_append_and_prepend(
list_array,
@@ -2284,18 +2289,4 @@ mod tests {
expected_dim
);
}

#[test]
fn test_check_invalid_datatypes() {
let data = vec![Some(vec![Some(1), Some(2), Some(3)])];
let list_array =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef;

let args = [list_array.clone(), int64_array.clone()];

let array = array_append(&args);

assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'.");
}
}
51 changes: 37 additions & 14 deletions datafusion/sqllogictest/test_files/array.slt
@@ -297,10 +297,8 @@ AS VALUES
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10)
;

query ?
query error
select [1, true, null]
Review comment (Contributor):

This is an error because `true` can't be coerced to an integer, right? FWIW I think that is fine and is consistent with the Postgres rules:

postgres=# select array[1, true, null];
ERROR:  ARRAY types integer and boolean cannot be matched
LINE 1: select array[1, true, null];

Review comment (Contributor Author):

yes

----
[1, 1, ]

query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now()
SELECT [now()]
@@ -1253,18 +1251,43 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3

## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`)

# TODO: array_append with NULLs
# array_append scalar function #1
# query ?
# select array_append(make_array(), 4);
# ----
# [4]
# array_append with NULLs

# array_append scalar function #2
# query ??
# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4));
# ----
# [[]] [[4]]
query error
select array_append(null, 1);

query error
select array_append(null, [2, 3]);

query error
select array_append(null, [[4]]);

query ????
select
array_append(make_array(), 4),
array_append(make_array(), null),
array_append(make_array(1, null, 3), 4),
array_append(make_array(null, null), 1)
;
----
[4] [] [1, , 3, 4] [, , 1]

# test invalid (non-null)
query error
select array_append(1, 2);

query error
select array_append(1, [2]);

query error
select array_append([1], [2]);

query ??
select
array_append(make_array(make_array(1, null, 3)), make_array(null)),
array_append(make_array(make_array(1, null, 3)), null);
----
[[1, , 3], []] [[1, , 3], ]

# array_append scalar function #3
query ???