Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 61 additions & 13 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, as_largestring_array};
use arrow::datatypes::DataType;
use arrow::array::{Array, as_largestring_array, ArrayRef};
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::*;
use datafusion_expr::sort_properties::ExprProperties;
use std::any::Any;
use std::sync::Arc;
Expand All @@ -26,10 +27,10 @@ use crate::strings::{
ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
};
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
use datafusion_expr::expr::ScalarFunction;
use datafusion_common::{Result, ScalarValue, internal_err, plan_err, exec_err};
use datafusion_expr::expr::{ScalarFunction, ScalarFunctionExpr};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit, BuiltinScalarFunction};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BuiltinScalarFunction has been removed from DataFusion long time ago...

use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_macros::user_doc;

Expand Down Expand Up @@ -67,13 +68,42 @@ impl ConcatFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::variadic(
vec![Utf8View, Utf8, LargeUtf8],
Volatility::Immutable,
signature: Signature::variadic_any(Volatility::Immutable)
),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
),
,

}
}
}

// Check if any argument is an array type
fn has_array_args(args: &[ColumnarValue]) -> bool {
use arrow::datatypes::DataType::*;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
use arrow::datatypes::DataType::*;

args.iter().any(|arg| match arg {
ColumnarValue::Array(arr) => matches!(arr.data_type(), List(_)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ColumnarValue::Array(arr) => matches!(arr.data_type(), List(_)),
ColumnarValue::Array(arr) => matches!(arr.data_type(), DataType::List(_)),

ColumnarValue::Scalar(scalar) => matches!(scalar, ScalarValue::List(_, _)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List(Arc<ListArray>),
- ScalarValue::List(_) accepts just one parameter

})
}

// Convert arguments to array_concat function
fn to_array_concat(args: Vec<Expr>) -> Result<Expr> {
let array_concat = BuiltinScalarFunction::ArrayConcat;
let args = args.into_iter()
.map(|arg| {
// If the argument is not already an array, wrap it in an array
if !matches!(arg.get_type(&DataType::Null).unwrap_or_default(), DataType::List(_)) {
Ok(Expr::ScalarFunction(ScalarFunction::new(
BuiltinScalarFunction::MakeArray,
vec![arg],
)))
} else {
Ok(arg)
}
})
.collect::<Result<Vec<_>>>()?;

Ok(Expr::ScalarFunction(ScalarFunction::new(
array_concat,
args,
)))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
}
}

to close the impl block


impl ScalarUDFImpl for ConcatFunc {
fn as_any(&self) -> &dyn Any {
Expand All @@ -90,22 +120,40 @@ impl ScalarUDFImpl for ConcatFunc {

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

// If any argument is an array, return the array type
if arg_types.iter().any(|t| matches!(t, List(_))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if is not really needed. The inner one does the same.

// Find the first array type to use as the base
if let Some(DataType::List(field)) = arg_types.iter().find(|t| matches!(t, List(_))) {
return Ok(DataType::List(field.clone()));
}
}

// Otherwise, use the existing string type logic
let mut dt = &Utf8;
arg_types.iter().for_each(|data_type| {
for data_type in arg_types {
if data_type == &Utf8View {
dt = data_type;
}
if data_type == &LargeUtf8 && dt != &Utf8View {
dt = data_type;
}
});

Ok(dt.to_owned())
}
Ok(dt.clone())
}

/// Concatenates the text representations of all the arguments. NULL arguments are ignored.
/// concat('abcde', 2, NULL, 22) = 'abcde222'
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// Check if any argument is an array
if Self::has_array_args(&args.args) {
// Convert to array_concat expression and evaluate it
let exprs: Vec<Expr> = args.args.into_iter().map(|a| a.into_expr()).collect();
let expr = Self::to_array_concat(exprs)?;
return expr.eval(args.context);
}

// Original string concatenation logic
let ScalarFunctionArgs { args, .. } = args;

let mut return_datatype = DataType::Utf8;
Expand Down