diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index d522158f7b6b..de81ec5f0bac 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -252,7 +252,21 @@ impl ScalarUDF { Ok(result) } - /// Get the circuits of inner implementation + /// Determines which of the arguments passed to this function are evaluated eagerly + /// and which may be evaluated lazily. + /// + /// See [ScalarUDFImpl::conditional_arguments] for more information. + pub fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + + /// Returns true if some of this function's arguments may not be evaluated, + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// See [ScalarUDFImpl::short_circuits] for more information. pub fn short_circuits(&self) -> bool { self.inner.short_circuits() } @@ -656,10 +670,42 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// /// Setting this to true prevents certain optimizations such as common /// subexpression elimination + /// + /// When overriding this function to return `true`, [ScalarUDFImpl::conditional_arguments] can also be + /// overridden to report more accurately which arguments are eagerly evaluated and which ones + /// lazily. fn short_circuits(&self) -> bool { false } + /// Determines which of the arguments passed to this function are evaluated eagerly + /// and which may be evaluated lazily. + /// + /// If this function returns `None`, all arguments are eagerly evaluated. + /// Returning `None` is a micro-optimization that saves a needless `Vec` + /// allocation. + /// + /// If the function returns `Some`, the value is (`eager`, `lazy`) where `eager` + /// are the arguments that are always evaluated, and `lazy` are the + /// arguments that may be evaluated lazily (i.e. may not be evaluated at all + /// in some cases). 
+ /// + /// Implementations must ensure that the two returned `Vec`s are disjoint, + /// and that each argument from `args` is present in one of the two `Vec`s. + /// + /// When overriding this function, [ScalarUDFImpl::short_circuits] must + /// be overridden to return `true`. + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + if self.short_circuits() { + Some((vec![], args.iter().collect())) + } else { + None + } + } + /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input /// intervals. /// @@ -845,6 +891,13 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.simplify(args, info) } + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + fn short_circuits(&self) -> bool { self.inner.short_circuits() } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 3fba539dd04b..aab1f445d559 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -47,7 +47,7 @@ use std::any::Any; )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct CoalesceFunc { - signature: Signature, + pub(super) signature: Signature, } impl Default for CoalesceFunc { @@ -126,6 +126,15 @@ impl ScalarUDFImpl for CoalesceFunc { internal_err!("coalesce should have been simplified to case") } + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + let eager = vec![&args[0]]; + let lazy = args[1..].iter().collect(); + Some((eager, lazy)) + } + fn short_circuits(&self) -> bool { true } diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index c8b34c4b1780..0b9968a88fc9 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -15,21 +15,19 @@ // specific language governing permissions and limitations 
// under the License. -use arrow::array::Array; -use arrow::compute::is_not_null; -use arrow::compute::kernels::zip::zip; -use arrow::datatypes::DataType; -use datafusion_common::{utils::take_function_args, Result}; +use crate::core::coalesce::CoalesceFunc; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion_common::Result; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use std::sync::Arc; #[user_doc( doc_section(label = "Conditional Functions"), - description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.", + description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_ and _expression2_ is not evaluated. This function can be used to substitute a default value for NULL values.", syntax_example = "nvl(expression1, expression2)", sql_example = r#"```sql > select nvl(null, 'a'); @@ -57,7 +55,7 @@ use std::sync::Arc; )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct NVLFunc { - signature: Signature, + coalesce: CoalesceFunc, aliases: Vec, } @@ -90,11 +88,13 @@ impl Default for NVLFunc { impl NVLFunc { pub fn new() -> Self { Self { - signature: Signature::uniform( - 2, - SUPPORTED_NVL_TYPES.to_vec(), - Volatility::Immutable, - ), + coalesce: CoalesceFunc { + signature: Signature::uniform( + 2, + SUPPORTED_NVL_TYPES.to_vec(), + Volatility::Immutable, + ), + }, aliases: vec![String::from("ifnull")], } } @@ -110,209 +110,45 @@ impl ScalarUDFImpl for NVLFunc { } fn signature(&self) -> &Signature { - &self.signature + &self.coalesce.signature } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + self.coalesce.return_type(arg_types) } - fn invoke_with_args(&self, args: 
ScalarFunctionArgs) -> Result { - nvl_func(&args.args) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn documentation(&self) -> Option<&Documentation> { - self.doc() + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.coalesce.return_field_from_args(args) } -} - -fn nvl_func(args: &[ColumnarValue]) -> Result { - let [lhs, rhs] = take_function_args("nvl/ifnull", args)?; - let (lhs_array, rhs_array) = match (lhs, rhs) { - (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - (Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?) - } - (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { - (Arc::clone(lhs), Arc::clone(rhs)) - } - (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - (lhs.to_array_of_size(rhs.len())?, Arc::clone(rhs)) - } - (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { - let mut current_value = lhs; - if lhs.is_null() { - current_value = rhs; - } - return Ok(ColumnarValue::Scalar(current_value.clone())); - } - }; - let to_apply = is_not_null(&lhs_array)?; - let value = zip(&to_apply, &lhs_array, &rhs_array)?; - Ok(ColumnarValue::Array(value)) -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::*; - use super::*; - use datafusion_common::ScalarValue; - - #[test] - fn nvl_int32() -> Result<()> { - let a = Int32Array::from(vec![ - Some(1), - Some(2), - None, - None, - Some(3), - None, - None, - Some(4), - Some(5), - ]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(6i32))); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - Some(6), - Some(6), - Some(3), - Some(6), - Some(6), - Some(4), - Some(5), - ])) as ArrayRef; - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, + ) -> 
Result { + self.coalesce.simplify(args, info) } - #[test] - // Ensure that arrays with no nulls can also invoke nvl() correctly - fn nvl_int32_non_nulls() -> Result<()> { - let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(20i32))); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(Int32Array::from(vec![ - Some(1), - Some(3), - Some(10), - Some(7), - Some(8), - Some(1), - Some(2), - Some(4), - Some(5), - ])) as ArrayRef; - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.coalesce.invoke_with_args(args) } - #[test] - fn nvl_boolean() -> Result<()> { - let a = BooleanArray::from(vec![Some(true), Some(false), None]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(BooleanArray::from(vec![ - Some(true), - Some(false), - Some(false), - ])) as ArrayRef; - - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.coalesce.conditional_arguments(args) } - #[test] - fn nvl_string() -> Result<()> { - let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::from("bax")); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(StringArray::from(vec![ - Some("foo"), - Some("bar"), - Some("bax"), - Some("baz"), - ])) as ArrayRef; - - 
assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn short_circuits(&self) -> bool { + self.coalesce.short_circuits() } - #[test] - fn nvl_literal_first() -> Result<()> { - let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - - let result = nvl_func(&[lit_array, a])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(Int32Array::from(vec![ - Some(2), - Some(2), - Some(2), - Some(2), - Some(2), - Some(2), - ])) as ArrayRef; - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn aliases(&self) -> &[String] { + &self.aliases } - #[test] - fn nvl_scalar() -> Result<()> { - let a_null = ColumnarValue::Scalar(ScalarValue::Int32(None)); - let b_null = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - - let result_null = nvl_func(&[a_null, b_null])?; - let result_null = result_null - .into_array(1) - .expect("Failed to convert to array"); - - let expected_null = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; - - assert_eq!(expected_null.as_ref(), result_null.as_ref()); - - let a_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); - - let result_nnull = nvl_func(&[a_nnull, b_nnull])?; - let result_nnull = result_nnull - .into_array(1) - .expect("Failed to convert to array"); - - let expected_nnull = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; - assert_eq!(expected_nnull.as_ref(), result_nnull.as_ref()); - - Ok(()) + fn documentation(&self) -> Option<&Documentation> { + self.doc() } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ec1f8f991a8e..251006849459 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ 
b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -652,10 +652,8 @@ impl CSEController for ExprCSEController<'_> { // In case of `ScalarFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. - Expr::ScalarFunction(ScalarFunction { func, args }) - if func.short_circuits() => - { - Some((vec![], args.iter().collect())) + Expr::ScalarFunction(ScalarFunction { func, args }) => { + func.conditional_arguments(args) } // In case of `And` and `Or` the first child is surely executed, but we diff --git a/datafusion/sqllogictest/test_files/nvl.slt b/datafusion/sqllogictest/test_files/nvl.slt index daab54307cc2..f4225148ab78 100644 --- a/datafusion/sqllogictest/test_files/nvl.slt +++ b/datafusion/sqllogictest/test_files/nvl.slt @@ -148,3 +148,38 @@ query T SELECT NVL(arrow_cast('a', 'Utf8View'), NULL); ---- a + +# nvl is implemented as a case, and short-circuits evaluation +# so the following query should not error +query I +SELECT NVL(1, 1/0); +---- +1 + +# but this one should +query error DataFusion error: Arrow error: Divide by zero error +SELECT NVL(NULL, 1/0); + +# Expect nvl to return the column value, or the default (9999) where the column is NULL +query I +select NVL(int_field, 9999) FROM test; +---- +1 +2 +3 +9999 +4 +9999 + +# Expect the query plan to show nvl as a case expression +query TT +EXPLAIN select NVL(int_field, 9999) FROM test; +---- +logical_plan +01)Projection: CASE WHEN __common_expr_1 IS NOT NULL THEN __common_expr_1 ELSE Int64(9999) END AS nvl(test.int_field,Int64(9999)) +02)--Projection: CAST(test.int_field AS Int64) AS __common_expr_1 +03)----TableScan: test projection=[int_field] +physical_plan +01)ProjectionExec: expr=[CASE WHEN __common_expr_1@0 IS NOT NULL THEN __common_expr_1@0 ELSE 9999 END as nvl(test.int_field,Int64(9999))] +02)--ProjectionExec: expr=[CAST(int_field@0 AS Int64) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] 
diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index fb67daa0b840..4d30f572ad6f 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -988,7 +988,7 @@ query TT EXPLAIN SELECT NVL(column1_utf8view, 'a') as c2 FROM test; ---- logical_plan -01)Projection: nvl(test.column1_utf8view, Utf8View("a")) AS c2 +01)Projection: CASE WHEN test.column1_utf8view IS NOT NULL THEN test.column1_utf8view ELSE Utf8View("a") END AS c2 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for nullif diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 9fcaac762855..ec2faf8b3d5d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1056,7 +1056,7 @@ nullif(expression1, expression2) ### `nvl` -Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_. +Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_ and _expression2_ is not evaluated. This function can be used to substitute a default value for NULL values. ```sql nvl(expression1, expression2)