diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 42d455a05760..28bb0f1ba426 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, as_largestring_array}; -use arrow::datatypes::DataType; +use arrow::array::{Array, as_largestring_array, ArrayRef}; +use arrow::datatypes::{DataType, Field, TimeUnit}; +use arrow_array::types::*; use datafusion_expr::sort_properties::ExprProperties; use std::any::Any; use std::sync::Arc; @@ -26,10 +27,10 @@ use crate::strings::{ ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder, }; use datafusion_common::cast::{as_string_array, as_string_view_array}; -use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; -use datafusion_expr::expr::ScalarFunction; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err, exec_err}; +use datafusion_expr::expr::{ScalarFunction, ScalarFunctionExpr}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit, BuiltinScalarFunction}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -67,13 +68,42 @@ impl ConcatFunc { pub fn new() -> Self { use DataType::*; Self { - signature: Signature::variadic( - vec![Utf8View, Utf8, LargeUtf8], - Volatility::Immutable, + signature: Signature::variadic_any(Volatility::Immutable) ), } } -} + + // Check if any argument is an array type + fn has_array_args(args: &[ColumnarValue]) -> bool { + use arrow::datatypes::DataType::*; + args.iter().any(|arg| match arg { + ColumnarValue::Array(arr) => matches!(arr.data_type(), List(_)), + ColumnarValue::Scalar(scalar) => matches!(scalar, ScalarValue::List(_, _)), + }) + } + + // Convert arguments to array_concat function + fn to_array_concat(args: Vec) -> Result { + let array_concat = BuiltinScalarFunction::ArrayConcat; + let args = args.into_iter() + .map(|arg| { + // If the argument is not already an array, wrap it in an array + if !matches!(arg.get_type(&DataType::Null).unwrap_or_default(), DataType::List(_)) { + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![arg], + ))) + } else { + Ok(arg) + } + }) + .collect::>>()?; + + Ok(Expr::ScalarFunction(ScalarFunction::new( + array_concat, + args, + ))) + } impl ScalarUDFImpl for ConcatFunc { fn as_any(&self) -> &dyn Any { @@ -90,22 +120,40 @@ impl ScalarUDFImpl for ConcatFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; + + // If any argument is an array, return the array type + if arg_types.iter().any(|t| matches!(t, List(_))) { + // Find the first array type to use as the base + if let Some(DataType::List(field)) = arg_types.iter().find(|t| matches!(t, List(_))) { + return Ok(DataType::List(field.clone())); + } + } + + // Otherwise, use the existing string type logic let mut dt = &Utf8; - arg_types.iter().for_each(|data_type| { + for data_type in arg_types { if data_type == &Utf8View { dt = data_type; } if data_type == &LargeUtf8 && dt != &Utf8View { dt = data_type; } - }); - - Ok(dt.to_owned()) + } + Ok(dt.clone()) } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // Check if any argument is an array + if Self::has_array_args(&args.args) { + // Convert to array_concat expression and evaluate it + let exprs: Vec = args.args.into_iter().map(|a| a.into_expr()).collect(); + let expr = Self::to_array_concat(exprs)?; + return expr.eval(args.context); + } + + // Original string concatenation logic let ScalarFunctionArgs { args, .. } = args; let mut return_datatype = DataType::Utf8;