Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions datafusion/functions/src/string/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ impl ScalarUDFImpl for AsciiFunc {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(Int32)
Ok(DataType::Int32)
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -186,6 +184,8 @@ mod tests {
test_ascii!(Some(String::from("a")), Ok(Some(97)));
test_ascii!(Some(String::from("")), Ok(Some(0)));
test_ascii!(Some(String::from("🚀")), Ok(Some(128640)));
test_ascii!(Some(String::from("\n")), Ok(Some(10)));
test_ascii!(Some(String::from("\t")), Ok(Some(9)));
Comment on lines +187 to +188
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consolidating the tests here and removing them from the Spark unit tests (the Spark slt tests still remain).

test_ascii!(None, Ok(None));
Ok(())
}
Expand Down
117 changes: 9 additions & 108 deletions datafusion/spark/src/function/string/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
use datafusion_common::{internal_err, plan_err, Result};
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use datafusion_functions::string::ascii::ascii;
use datafusion_functions::utils::make_scalar_function;
use std::any::Any;
use std::sync::Arc;

/// <https://spark.apache.org/docs/latest/api/sql/index.html#ascii>
/// Spark compatible version of the [ascii] function. Differs from the [default ascii function]
/// in that it is more permissive of input types, for example casting numeric input to string
/// before executing the function (default version doesn't allow numeric input).
///
/// [ascii]: https://spark.apache.org/docs/latest/api/sql/index.html#ascii
/// [default ascii function]: datafusion_functions::string::ascii::AsciiFunc
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkAscii {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered removing this entirely and giving AsciiFunc a toggleable behaviour instead, but keeping it as a separate struct was easier to work with, given the existing macros used to export the UDFs.

signature: Signature,
aliases: Vec<String>,
}

impl Default for SparkAscii {
Expand All @@ -42,7 +44,6 @@ impl SparkAscii {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![],
}
}
}
Expand All @@ -68,107 +69,7 @@ impl ScalarUDFImpl for SparkAscii {
make_scalar_function(ascii, vec![])(&args.args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return plan_err!(
"The {} function requires 1 argument, but got {}.",
self.name(),
arg_types.len()
);
}
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
Ok(vec![DataType::Utf8])
}
}

fn calculate_ascii<'a, V>(array: V) -> Result<ArrayRef, ArrowError>
where
V: ArrayAccessor<Item = &'a str>,
{
let iter = ArrayIter::new(array);
let result = iter
.map(|string| {
string.map(|s| {
let mut chars = s.chars();
chars.next().map_or(0, |v| v as i32)
})
})
.collect::<Int32Array>();

Ok(Arc::new(result) as ArrayRef)
}

/// Returns the numeric code of the first character of the argument.
pub fn ascii(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
Ok(calculate_ascii(string_array)?)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
Ok(calculate_ascii(string_array)?)
}
DataType::Utf8View => {
let string_array = args[0].as_string_view();
Ok(calculate_ascii(string_array)?)
}
_ => internal_err!("Unsupported data type"),
}
}

#[cfg(test)]
mod tests {
use crate::function::string::ascii::SparkAscii;
use crate::function::utils::test::test_scalar_function;
use arrow::array::{Array, Int32Array};
use arrow::datatypes::DataType::Int32;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

macro_rules! test_ascii_string_invoke {
($INPUT:expr, $EXPECTED:expr) => {
test_scalar_function!(
SparkAscii::new(),
vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);

test_scalar_function!(
SparkAscii::new(),
vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);

test_scalar_function!(
SparkAscii::new(),
vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);
};
}

#[test]
fn test_ascii_invoke() -> Result<()> {
test_ascii_string_invoke!(Some(String::from("x")), Ok(Some(120)));
test_ascii_string_invoke!(Some(String::from("a")), Ok(Some(97)));
test_ascii_string_invoke!(Some(String::from("")), Ok(Some(0)));
test_ascii_string_invoke!(Some(String::from("\n")), Ok(Some(10)));
test_ascii_string_invoke!(Some(String::from("\t")), Ok(Some(9)));
test_ascii_string_invoke!(None, Ok(None));

Ok(())
}
}
11 changes: 6 additions & 5 deletions datafusion/spark/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
//! # Example: using all function packages
//!
//! You can register all the functions in all packages using the [`register_all`]
//! function as shown below.
//! function as shown below. Any existing functions will be overwritten, with these
//! Spark functions taking priority.
//!
//! ```
//! # use datafusion_execution::FunctionRegistry;
Expand Down Expand Up @@ -68,10 +69,9 @@
//! # async fn stub() -> Result<()> {
//! // Create a new session context
//! let mut ctx = SessionContext::new();
//! // register all spark functions with the context
//! // Register all Spark functions with the context
//! datafusion_spark::register_all(&mut ctx)?;
//! // run a query. Note the `sha2` function is now available which
//! // has Spark semantics
//! // Run a query using the `sha2` function which is now available and has Spark semantics
//! let df = ctx.sql("SELECT sha2('The input String', 256)").await?;
//! # Ok(())
//! # }
Expand Down Expand Up @@ -170,7 +170,8 @@ pub fn all_default_table_functions() -> Vec<Arc<TableFunction>> {
function::table::functions()
}

/// Registers all enabled packages with a [`FunctionRegistry`]
/// Registers all enabled packages with a [`FunctionRegistry`], overriding any existing
/// functions if there is a name clash.
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let scalar_functions: Vec<Arc<ScalarUDF>> = all_default_scalar_functions();
scalar_functions.into_iter().try_for_each(|udf| {
Expand Down