Skip to content

Commit

Permalink
fix array_append with null array
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <code@tanweime.com>
  • Loading branch information
Veeupup committed Nov 27, 2023
1 parent d81c961 commit fd94178
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 32 deletions.
18 changes: 17 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,23 @@ impl BuiltinScalarFunction {
let data_type = get_base_type(&input_expr_types[0])?;
Ok(data_type)
}
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayAppend => {
if let DataType::List(field) = input_expr_types[0].clone() {
if field.data_type().equals_datatype(&DataType::Null) {
Ok(DataType::List(Arc::new(Field::new(
"item",
input_expr_types[1].clone(),
true,
))))
} else {
Ok(input_expr_types[0].clone())
}
} else {
plan_err!(
"The {self} function can only accept list as the first argument"
)
}
}
BuiltinScalarFunction::ArrayConcat => {
let mut expr_type = Null;
let mut max_dims = 0;
Expand Down
45 changes: 22 additions & 23 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,29 @@ fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>>
}

fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
let data_type = args[0].data_type();
if !args.iter().all(|arg| {
arg.data_type().equals_datatype(data_type)
|| arg.data_type().equals_datatype(&DataType::Null)
}) {
let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
return plan_err!("{name} received incompatible types: '{types:?}'.");
let mut data_types = args
.iter()
.map(|arg| arg.data_type())
.collect::<HashSet<_>>();
match data_types.len() {
1 => Ok(()),
2 => {
if data_types.remove(&DataType::Null) {
Ok(())
} else {
Err(DataFusionError::Plan(format!(
"{name} received incompatible types: '{types:?}'.",
name = name,
types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>()
)))
}
}
_ => Err(DataFusionError::Plan(format!(
"{name} received incompatible types: '{types:?}'.",
name = name,
types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>()
))),
}

Ok(())
}

macro_rules! call_array_function {
Expand Down Expand Up @@ -2951,20 +2964,6 @@ mod tests {
assert_eq!(result, &UInt64Array::from_value(2, 1));
}

#[test]
fn test_check_invalid_datatypes() {
let data = vec![Some(vec![Some(1), Some(2), Some(3)])];
let list_array =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef;

let args = [list_array.clone(), int64_array.clone()];

let array = array_append(&args);

assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'.");
}

fn return_array() -> ArrayRef {
// Returns: [1, 2, 3, 4]
let args = [
Expand Down
16 changes: 8 additions & 8 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1056,16 +1056,16 @@ select make_array(['a','b'], null);

# TODO: array_append with NULLs
# array_append scalar function #1
# query ?
# select array_append(make_array(), 4);
# ----
# [4]
query ?
select array_append(make_array(), 4);
----
[4]

# array_append scalar function #2
# query ??
# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4));
# ----
# [[]] [[4]]
query ??
select array_append(make_array(), make_array()), array_append(make_array(), make_array(4));
----
[[]] [[4]]

# array_append scalar function #3
query ???
Expand Down

0 comments on commit fd94178

Please sign in to comment.