Skip to content

Commit

Permalink
Validate and unpack function arguments tersely
Browse files Browse the repository at this point in the history
Add a `take_function_args` helper that provides convenient unpacking of
function arguments along with validation that the provided argument
count matches the expected count. A few functions are updated to leverage
the new pattern to demonstrate its usefulness.
  • Loading branch information
findepi committed Feb 5, 2025
1 parent 2ad28e0 commit 8d251d5
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 120 deletions.
6 changes: 3 additions & 3 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::{
use datafusion_common::{exec_datafusion_err, DataFusionError};
use std::any::Any;

use crate::utils::take_function_args;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl,
Expand Down Expand Up @@ -117,10 +118,9 @@ impl ScalarUDFImpl for ArrowCastFunc {
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
let nullable = args.nullables.iter().any(|&nullable| nullable);

// Length check handled in the signature
debug_assert_eq!(args.scalar_arguments.len(), 2);
let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;

args.scalar_arguments[1]
type_arg
.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
.map_or_else(
|| {
Expand Down
13 changes: 4 additions & 9 deletions datafusion/functions/src/core/arrowtypeof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::any::Any;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Other Functions"),
Expand Down Expand Up @@ -80,14 +81,8 @@ impl ScalarUDFImpl for ArrowTypeOfFunc {
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
"arrow_typeof function requires 1 arguments, got {}",
args.len()
);
}

let input_data_type = args[0].data_type();
let [arg] = take_function_args(self.name(), args)?;
let input_data_type = arg.data_type();
Ok(ColumnarValue::Scalar(ScalarValue::from(format!(
"{input_data_type}"
))))
Expand Down
12 changes: 4 additions & 8 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::any::Any;
use std::sync::Arc;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Other Functions"),
Expand Down Expand Up @@ -99,14 +100,9 @@ impl ScalarUDFImpl for GetFieldFunc {
}

fn display_name(&self, args: &[Expr]) -> Result<String> {
if args.len() != 2 {
return exec_err!(
"get_field function requires 2 arguments, got {}",
args.len()
);
}
let [base, field_name] = take_function_args(self.name(), args)?;

let name = match &args[1] {
let name = match field_name {
Expr::Literal(name) => name,
_ => {
return exec_err!(
Expand All @@ -115,7 +111,7 @@ impl ScalarUDFImpl for GetFieldFunc {
}
};

Ok(format!("{}[{}]", args[0], name))
Ok(format!("{}[{}]", base, name))
}

fn schema_name(&self, args: &[Expr]) -> Result<String> {
Expand Down
13 changes: 4 additions & 9 deletions datafusion/functions/src/core/nullif.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::{exec_err, Result};
use datafusion_common::{Result};
use datafusion_expr::{ColumnarValue, Documentation};

use arrow::compute::kernels::cmp::eq;
Expand All @@ -25,6 +25,8 @@ use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::any::Any;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Conditional Functions"),
description = "Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_.
Expand Down Expand Up @@ -119,14 +121,7 @@ impl ScalarUDFImpl for NullIfFunc {
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
///
fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!(
"{:?} args were supplied but NULLIF takes exactly two args",
args.len()
);
}

let (lhs, rhs) = (&args[0], &args[1]);
let [lhs, rhs] = take_function_args("nullif", args)?;

match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
Expand Down
13 changes: 5 additions & 8 deletions datafusion/functions/src/core/nvl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ use arrow::array::Array;
use arrow::compute::is_not_null;
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, Result};
use datafusion_common::{Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Conditional Functions"),
description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.",
Expand Down Expand Up @@ -133,13 +135,8 @@ impl ScalarUDFImpl for NVLFunc {
}

fn nvl_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return internal_err!(
"{:?} args were supplied but NVL/IFNULL takes exactly two args",
args.len()
);
}
let (lhs_array, rhs_array) = match (&args[0], &args[1]) {
let [lhs, rhs] = take_function_args("nvl/ifnull", args)?;
let (lhs_array, rhs_array) = match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
(Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?)
}
Expand Down
14 changes: 5 additions & 9 deletions datafusion/functions/src/core/nvl2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ use arrow::array::Array;
use arrow::compute::is_not_null;
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, internal_err, Result};
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Conditional Functions"),
Expand Down Expand Up @@ -104,14 +105,9 @@ impl ScalarUDFImpl for NVL2Func {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 3 {
return exec_err!(
"NVL2 takes exactly three arguments, but got {}",
arg_types.len()
);
}
let new_type = arg_types.iter().skip(1).try_fold(
arg_types.first().unwrap().clone(),
let [a, b, c] = take_function_args(self.name(), arg_types)?;
let new_type = [b, c].iter().try_fold(
a.clone(),
|acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
Expand Down
15 changes: 6 additions & 9 deletions datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
//! [`VersionFunc`]: Implementation of the `version` function.
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Other Functions"),
description = "Returns the version of DataFusion.",
Expand Down Expand Up @@ -70,21 +72,16 @@ impl ScalarUDFImpl for VersionFunc {
}

fn return_type(&self, args: &[DataType]) -> Result<DataType> {
if args.is_empty() {
Ok(DataType::Utf8)
} else {
plan_err!("version expects no arguments")
}
let [] = take_function_args(self.name(), args)?;
Ok(DataType::Utf8)
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
if !args.is_empty() {
return internal_err!("{} function does not accept arguments", self.name());
}
let [] = take_function_args(self.name(), args)?;
// TODO it would be great to add rust version and arrow version,
// but that requires a `build.rs` script and/or adding a version const to arrow-rs
let version = format!(
Expand Down
33 changes: 8 additions & 25 deletions datafusion/functions/src/crypto/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,14 @@ use sha2::{Sha224, Sha256, Sha384, Sha512};
use std::fmt::{self, Write};
use std::str::FromStr;
use std::sync::Arc;
use crate::utils::take_function_args;

macro_rules! define_digest_function {
($NAME: ident, $METHOD: ident, $DOC: expr) => {
#[doc = $DOC]
pub fn $NAME(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
DigestAlgorithm::$METHOD.to_string()
);
}
digest_process(&args[0], DigestAlgorithm::$METHOD)
let [data] = take_function_args(&DigestAlgorithm::$METHOD.to_string(), args)?;
digest_process(data, DigestAlgorithm::$METHOD)
}
};
}
Expand Down Expand Up @@ -114,13 +109,8 @@ pub enum DigestAlgorithm {
/// Second argument is the algorithm to use.
/// Standard algorithms are md5, sha1, sha224, sha256, sha384 and sha512.
pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!(
"{:?} args were supplied but digest takes exactly two arguments",
args.len()
);
}
let digest_algorithm = match &args[1] {
let [data, digest_algorithm] = take_function_args("digest", args)?;
let digest_algorithm = match digest_algorithm {
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(method)) => method.parse::<DigestAlgorithm>(),
_ => exec_err!("Unsupported data type {scalar:?} for function digest"),
Expand All @@ -129,7 +119,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
internal_err!("Digest using dynamically decided method is not yet supported")
}
}?;
digest_process(&args[0], digest_algorithm)
digest_process(data, digest_algorithm)
}

impl FromStr for DigestAlgorithm {
Expand Down Expand Up @@ -175,15 +165,8 @@ impl fmt::Display for DigestAlgorithm {

/// computes md5 hash digest of the given input
pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
DigestAlgorithm::Md5
);
}

let value = digest_process(&args[0], DigestAlgorithm::Md5)?;
let [data] = take_function_args("md5", args)?;
let value = digest_process(data, DigestAlgorithm::Md5)?;

// md5 requires special handling because of its unique utf8 return type
Ok(match value {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/functions/src/datetime/date_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use datafusion_expr::{
};
use datafusion_expr_common::signature::TypeSignatureClass;
use datafusion_macros::user_doc;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Time and Date Functions"),
Expand Down Expand Up @@ -140,10 +141,9 @@ impl ScalarUDFImpl for DatePartFunc {
}

fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
// Length check handled in the signature
debug_assert_eq!(args.scalar_arguments.len(), 2);
let [field, _] = take_function_args(self.name(), args.scalar_arguments)?;

args.scalar_arguments[0]
field
.and_then(|sv| {
sv.try_as_str()
.flatten()
Expand Down
16 changes: 6 additions & 10 deletions datafusion/functions/src/datetime/make_date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Time and Date Functions"),
Expand Down Expand Up @@ -111,13 +112,6 @@ impl ScalarUDFImpl for MakeDateFunc {
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
if args.len() != 3 {
return exec_err!(
"make_date function requires 3 arguments, got {}",
args.len()
);
}

// first, identify if any of the arguments is an Array. If yes, store its `len`,
// as any scalar will need to be converted to an array of len `len`.
let len = args
Expand All @@ -127,9 +121,11 @@ impl ScalarUDFImpl for MakeDateFunc {
ColumnarValue::Array(a) => Some(a.len()),
});

let years = args[0].cast_to(&Int32, None)?;
let months = args[1].cast_to(&Int32, None)?;
let days = args[2].cast_to(&Int32, None)?;
let [years, months, days] = take_function_args(self.name(), args)?;

let years = years.cast_to(&Int32, None)?;
let months = months.cast_to(&Int32, None)?;
let days = days.cast_to(&Int32, None)?;

let scalar_value_fn = |col: &ColumnarValue| -> Result<i32> {
let ColumnarValue::Scalar(s) = col else {
Expand Down
16 changes: 6 additions & 10 deletions datafusion/functions/src/datetime/to_char.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
};
use datafusion_macros::user_doc;
use crate::utils::take_function_args;

#[user_doc(
doc_section(label = "Time and Date Functions"),
Expand Down Expand Up @@ -140,28 +141,23 @@ impl ScalarUDFImpl for ToCharFunc {
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!(
"to_char function requires 2 arguments, got {}",
args.len()
);
}
let [date_time, format] = take_function_args(self.name(), args)?;

match &args[1] {
match format {
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Null) => {
_to_char_scalar(args[0].clone(), None)
_to_char_scalar(date_time.clone(), None)
}
// constant format
ColumnarValue::Scalar(ScalarValue::Utf8(Some(format))) => {
// invoke to_char_scalar with the known string, without converting to array
_to_char_scalar(args[0].clone(), Some(format))
_to_char_scalar(date_time.clone(), Some(format))
}
ColumnarValue::Array(_) => _to_char_array(args),
_ => {
exec_err!(
"Format for `to_char` must be non-null Utf8, received {:?}",
args[1].data_type()
format.data_type()
)
}
}
Expand Down
Loading

0 comments on commit 8d251d5

Please sign in to comment.