-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move Coercion for MakeArray to coerce_arguments_for_signature
and introduce another one for ArrayAppend
#8317
Changes from all commits
3ba90b4
32c9931
f760de5
6384a05
3686a2f
3226add
1045066
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,10 @@ use arrow::{ | |
compute::can_cast_types, | ||
datatypes::{DataType, TimeUnit}, | ||
}; | ||
use datafusion_common::{plan_err, DataFusionError, Result}; | ||
use datafusion_common::utils::list_ndims; | ||
use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; | ||
|
||
use super::binary::comparison_coercion; | ||
|
||
/// Performs type coercion for function arguments. | ||
/// | ||
|
@@ -86,16 +89,66 @@ fn get_valid_types( | |
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) | ||
.collect(), | ||
TypeSignature::VariadicEqual => { | ||
// one entry with the same len as current_types, whose type is `current_types[0]`. | ||
vec![current_types | ||
.iter() | ||
.map(|_| current_types[0].clone()) | ||
.collect()] | ||
let new_type = current_types.iter().skip(1).try_fold( | ||
current_types.first().unwrap().clone(), | ||
|acc, x| { | ||
let coerced_type = comparison_coercion(&acc, x); | ||
if let Some(coerced_type) = coerced_type { | ||
Ok(coerced_type) | ||
} else { | ||
internal_err!("Coercion from {acc:?} to {x:?} failed.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We call unwrap_or previously so |
||
} | ||
}, | ||
); | ||
|
||
match new_type { | ||
Ok(new_type) => vec![vec![new_type; current_types.len()]], | ||
Err(e) => return Err(e), | ||
} | ||
} | ||
TypeSignature::VariadicAny => { | ||
vec![current_types.to_vec()] | ||
} | ||
|
||
TypeSignature::Exact(valid_types) => vec![valid_types.clone()], | ||
TypeSignature::ArrayAndElement => { | ||
if current_types.len() != 2 { | ||
return Ok(vec![vec![]]); | ||
} | ||
|
||
let array_type = ¤t_types[0]; | ||
let elem_type = ¤t_types[1]; | ||
|
||
// We follow Postgres on `array_append(Null, T)`, which is not valid. | ||
if array_type.eq(&DataType::Null) { | ||
return Ok(vec![vec![]]); | ||
} | ||
|
||
// We need to find the coerced base type, mainly for cases like: | ||
// `array_append(List(null), i64)` -> `List(i64)` | ||
let array_base_type = datafusion_common::utils::base_type(array_type); | ||
let elem_base_type = datafusion_common::utils::base_type(elem_type); | ||
let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); | ||
|
||
if new_base_type.is_none() { | ||
return internal_err!( | ||
"Coercion from {array_base_type:?} to {elem_base_type:?} not supported." | ||
); | ||
} | ||
let new_base_type = new_base_type.unwrap(); | ||
|
||
let array_type = datafusion_common::utils::coerced_type_with_base_type_only( | ||
array_type, | ||
&new_base_type, | ||
); | ||
|
||
if let DataType::List(ref field) = array_type { | ||
let elem_type = field.data_type(); | ||
return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); | ||
} else { | ||
return Ok(vec![vec![]]); | ||
} | ||
} | ||
TypeSignature::Any(number) => { | ||
if current_types.len() != *number { | ||
return plan_err!( | ||
|
@@ -241,6 +294,15 @@ fn coerced_from<'a>( | |
Utf8 | LargeUtf8 => Some(type_into.clone()), | ||
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), | ||
|
||
// Only accept list with the same number of dimensions unless the type is Null. | ||
// List with different dimensions should be handled in TypeSignature or other places before this. | ||
List(_) | ||
if datafusion_common::utils::base_type(type_from).eq(&Null) | ||
|| list_ndims(type_from) == list_ndims(type_into) => | ||
{ | ||
Some(type_into.clone()) | ||
} | ||
|
||
Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { | ||
match type_from { | ||
Timestamp(_, Some(from_tz)) => { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -590,26 +590,6 @@ fn coerce_arguments_for_fun( | |
.collect::<Result<Vec<_>>>()?; | ||
} | ||
|
||
if *fun == BuiltinScalarFunction::MakeArray { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❤️ |
||
// Find the final data type for the function arguments | ||
let current_types = expressions | ||
.iter() | ||
.map(|e| e.get_type(schema)) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
let new_type = current_types | ||
.iter() | ||
.skip(1) | ||
.fold(current_types.first().unwrap().clone(), |acc, x| { | ||
comparison_coercion(&acc, x).unwrap_or(acc) | ||
}); | ||
|
||
return expressions | ||
.iter() | ||
.zip(current_types) | ||
.map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) | ||
.collect(); | ||
} | ||
Ok(expressions) | ||
} | ||
|
||
|
@@ -618,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr> | |
expr.clone().cast_to(to_type, schema) | ||
} | ||
|
||
/// Cast array `expr` to the specified type, if possible | ||
fn cast_array_expr( | ||
expr: &Expr, | ||
from_type: &DataType, | ||
to_type: &DataType, | ||
schema: &DFSchema, | ||
) -> Result<Expr> { | ||
if from_type.equals_datatype(&DataType::Null) { | ||
Ok(expr.clone()) | ||
} else { | ||
cast_expr(expr, to_type, schema) | ||
} | ||
} | ||
|
||
/// Returns the coerced exprs for each `input_exprs`. | ||
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the | ||
/// data type of `input_exprs` need to be coerced. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -297,10 +297,8 @@ AS VALUES | |
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) | ||
; | ||
|
||
query ? | ||
query error | ||
select [1, true, null] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is an error because
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||
---- | ||
[1, 1, ] | ||
|
||
query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() | ||
SELECT [now()] | ||
|
@@ -1253,18 +1251,43 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3 | |
|
||
## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) | ||
|
||
# TODO: array_append with NULLs | ||
# array_append scalar function #1 | ||
# query ? | ||
# select array_append(make_array(), 4); | ||
# ---- | ||
# [4] | ||
# array_append with NULLs | ||
|
||
# array_append scalar function #2 | ||
# query ?? | ||
# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); | ||
# ---- | ||
# [[]] [[4]] | ||
query error | ||
select array_append(null, 1); | ||
|
||
query error | ||
select array_append(null, [2, 3]); | ||
|
||
query error | ||
select array_append(null, [[4]]); | ||
|
||
query ???? | ||
select | ||
array_append(make_array(), 4), | ||
array_append(make_array(), null), | ||
array_append(make_array(1, null, 3), 4), | ||
array_append(make_array(null, null), 1) | ||
; | ||
---- | ||
[4] [] [1, , 3, 4] [, , 1] | ||
|
||
# test invalid (non-null) | ||
query error | ||
select array_append(1, 2); | ||
|
||
query error | ||
select array_append(1, [2]); | ||
|
||
query error | ||
select array_append([1], [2]); | ||
|
||
query ?? | ||
select | ||
array_append(make_array(make_array(1, null, 3)), make_array(null)), | ||
array_append(make_array(make_array(1, null, 3)), null); | ||
---- | ||
[[1, , 3], []] [[1, , 3], ] | ||
|
||
# array_append scalar function #3 | ||
query ??? | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about using a more generic name. Perhaps something like
There may be more general ways of expressing the array function types too 🤔