diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 3b9778a6e97e..bdf30833127a 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -87,9 +87,7 @@ impl ScalarUDFImpl for AsciiFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(Int32) + Ok(DataType::Int32) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -186,6 +184,8 @@ mod tests { test_ascii!(Some(String::from("a")), Ok(Some(97))); test_ascii!(Some(String::from("")), Ok(Some(0))); test_ascii!(Some(String::from("🚀")), Ok(Some(128640))); + test_ascii!(Some(String::from("\n")), Ok(Some(10))); + test_ascii!(Some(String::from("\t")), Ok(Some(9))); test_ascii!(None, Ok(None)); Ok(()) } diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs index bf66a19738a1..f14a66d4e484 100644 --- a/datafusion/spark/src/function/string/ascii.rs +++ b/datafusion/spark/src/function/string/ascii.rs @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; use arrow::datatypes::DataType; -use arrow::error::ArrowError; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_functions::string::ascii::ascii; use datafusion_functions::utils::make_scalar_function; use std::any::Any; -use std::sync::Arc; -/// +/// Spark compatible version of the [ascii] function. Differs from the [default ascii function] +/// in that it is more permissive of input types, for example casting numeric input to string +/// before executing the function (default version doesn't allow numeric input). +/// +/// [ascii]: https://spark.apache.org/docs/latest/api/sql/index.html#ascii +/// [default ascii function]: datafusion_functions::string::ascii::AsciiFunc #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkAscii { signature: Signature, - aliases: Vec, } impl Default for SparkAscii { @@ -42,7 +44,6 @@ impl SparkAscii { pub fn new() -> Self { Self { signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], } } } @@ -68,107 +69,7 @@ impl ScalarUDFImpl for SparkAscii { make_scalar_function(ascii, vec![])(&args.args) } - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return plan_err!( - "The {} function requires 1 argument, but got {}.", - self.name(), - arg_types.len() - ); - } + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { Ok(vec![DataType::Utf8]) } } - -fn calculate_ascii<'a, V>(array: V) -> Result -where - V: ArrayAccessor, -{ - let iter = ArrayIter::new(array); - let result = iter - .map(|string| { - string.map(|s| { - let mut chars = s.chars(); - chars.next().map_or(0, |v| v as i32) - }) - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns the numeric code of the first character of the argument. -pub fn ascii(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Utf8 => { - let string_array = args[0].as_string::(); - Ok(calculate_ascii(string_array)?) - } - DataType::LargeUtf8 => { - let string_array = args[0].as_string::(); - Ok(calculate_ascii(string_array)?) - } - DataType::Utf8View => { - let string_array = args[0].as_string_view(); - Ok(calculate_ascii(string_array)?) - } - _ => internal_err!("Unsupported data type"), - } -} - -#[cfg(test)] -mod tests { - use crate::function::string::ascii::SparkAscii; - use crate::function::utils::test::test_scalar_function; - use arrow::array::{Array, Int32Array}; - use arrow::datatypes::DataType::Int32; - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - - macro_rules! test_ascii_string_invoke { - ($INPUT:expr, $EXPECTED:expr) => { - test_scalar_function!( - SparkAscii::new(), - vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); - - test_scalar_function!( - SparkAscii::new(), - vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); - - test_scalar_function!( - SparkAscii::new(), - vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); - }; - } - - #[test] - fn test_ascii_invoke() -> Result<()> { - test_ascii_string_invoke!(Some(String::from("x")), Ok(Some(120))); - test_ascii_string_invoke!(Some(String::from("a")), Ok(Some(97))); - test_ascii_string_invoke!(Some(String::from("")), Ok(Some(0))); - test_ascii_string_invoke!(Some(String::from("\n")), Ok(Some(10))); - test_ascii_string_invoke!(Some(String::from("\t")), Ok(Some(9))); - test_ascii_string_invoke!(None, Ok(None)); - - Ok(()) - } -} diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index 1217b81e5a25..4d45f3c482af 100644 --- a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -37,7 +37,8 @@ //! # Example: using all function packages //! //! You can register all the functions in all packages using the [`register_all`] -//! function as shown below. +//! function as shown below. Any existing functions will be overwritten, with these +//! Spark functions taking priority. //! //! ``` //! # use datafusion_execution::FunctionRegistry; @@ -68,10 +69,9 @@ //! # async fn stub() -> Result<()> { //! // Create a new session context //! let mut ctx = SessionContext::new(); -//! // register all spark functions with the context +//! // Register all Spark functions with the context //! datafusion_spark::register_all(&mut ctx)?; -//! // run a query. Note the `sha2` function is now available which -//! // has Spark semantics +//! // Run a query using the `sha2` function which is now available and has Spark semantics //! let df = ctx.sql("SELECT sha2('The input String', 256)").await?; //! # Ok(()) //! # } @@ -170,7 +170,8 @@ pub fn all_default_table_functions() -> Vec> { function::table::functions() } -/// Registers all enabled packages with a [`FunctionRegistry`] +/// Registers all enabled packages with a [`FunctionRegistry`], overriding any existing +/// functions if there is a name clash. pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let scalar_functions: Vec> = all_default_scalar_functions(); scalar_functions.into_iter().try_for_each(|udf| {