diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index b42bef682b0d..ced9496f2174 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -65,9 +65,9 @@ unicode-segmentation = { version = "^1.7.1", optional = true } regex = { version = "^1.4.3", optional = true } lazy_static = { version = "^1.4.0", optional = true } smallvec = { version = "1.6", features = ["union"] } +rand = "0.8" [dev-dependencies] -rand = "0.8" criterion = "0.3" tempfile = "3" doc-comment = "0.3" diff --git a/datafusion/src/physical_plan/crypto_expressions.rs b/datafusion/src/physical_plan/crypto_expressions.rs index 8ad876b24d0c..dfb73f7af1da 100644 --- a/datafusion/src/physical_plan/crypto_expressions.rs +++ b/datafusion/src/physical_plan/crypto_expressions.rs @@ -16,7 +16,6 @@ // under the License. //! Crypto expressions -use std::sync::Arc; use md5::Md5; use sha2::{ @@ -97,23 +96,15 @@ where { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_binary_function::< - i32, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } + DataType::Utf8 => Ok(ColumnarValue::from( + unary_binary_function::(&[a.as_ref()], op, name)?, + )), DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_binary_function::< - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) + Ok(ColumnarValue::from(unary_binary_function::( + &[a.as_ref()], + op, + name, + )?)) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function {}", @@ -147,13 +138,9 @@ fn md5_array( pub fn md5(args: &[ColumnarValue]) -> Result { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new(md5_array::(&[ - a.as_ref() - ])?))), + DataType::Utf8 => Ok(ColumnarValue::from(md5_array::(&[a.as_ref()])?)), DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(md5_array::(&[ - a.as_ref() - ])?))) + Ok(ColumnarValue::from(md5_array::(&[a.as_ref()])?)) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function md5", diff --git a/datafusion/src/physical_plan/datetime_expressions.rs b/datafusion/src/physical_plan/datetime_expressions.rs index 7b5816186f27..f4cc312bee1e 100644 --- a/datafusion/src/physical_plan/datetime_expressions.rs +++ b/datafusion/src/physical_plan/datetime_expressions.rs @@ -231,12 +231,24 @@ where { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new( - unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, - ))), - DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( - unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, - ))), + DataType::Utf8 => { + Ok(ColumnarValue::from(unary_string_to_primitive_function::< + i32, + O, + _, + >( + &[a.as_ref()], op, name + )?)) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::from(unary_string_to_primitive_function::< + i64, + O, + _, + >( + &[a.as_ref()], op, name + )?)) + } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function {}", other, name, @@ -342,7 +354,7 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { .map(f) .collect::>()?; - ColumnarValue::Array(Arc::new(array)) + ColumnarValue::from(array) } }) } @@ -435,7 +447,7 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { 0, )?) } else { - ColumnarValue::Array(Arc::new(arr)) + ColumnarValue::from(arr) }) } @@ -461,8 +473,7 @@ mod tests { ts_builder.append_null()?; let expected_timestamps = &ts_builder.finish() as &dyn Array; - let string_array = - ColumnarValue::Array(Arc::new(string_builder.finish()) as ArrayRef); + let string_array = ColumnarValue::from(string_builder.finish() as ArrayRef); let parsed_timestamps = to_timestamp(&[string_array]) .expect("that to_timestamp parsed values without error"); if let ColumnarValue::Array(parsed_array) = parsed_timestamps { @@ -539,7 +550,7 @@ mod tests { let mut builder = Int64Array::builder(1); builder.append_value(1)?; - let int64array = ColumnarValue::Array(Arc::new(builder.finish())); + let int64array = ColumnarValue::from(builder.finish()); let expected_err = "Internal error: Unsupported data type Int64 for function to_timestamp"; diff --git a/datafusion/src/physical_plan/expressions/in_list.rs b/datafusion/src/physical_plan/expressions/in_list.rs index 41f111006ea2..717ea07f43e5 100644 --- a/datafusion/src/physical_plan/expressions/in_list.rs +++ b/datafusion/src/physical_plan/expressions/in_list.rs @@ -69,7 +69,7 @@ macro_rules! make_contains { }) .collect::>(); - Ok(ColumnarValue::Array(Arc::new( + Ok(ColumnarValue::from( array .iter() .map(|x| { @@ -95,7 +95,7 @@ macro_rules! make_contains { } }) .collect::(), - ))) + )) }}; } @@ -164,7 +164,7 @@ impl InListExpr { }) .collect::>(); - Ok(ColumnarValue::Array(Arc::new( + Ok(ColumnarValue::from( array .iter() .map(|x| { @@ -190,7 +190,7 @@ impl InListExpr { } }) .collect::(), - ))) + )) } } diff --git a/datafusion/src/physical_plan/expressions/is_not_null.rs b/datafusion/src/physical_plan/expressions/is_not_null.rs index 7ac2110b5022..9effa9506a7a 100644 --- a/datafusion/src/physical_plan/expressions/is_not_null.rs +++ b/datafusion/src/physical_plan/expressions/is_not_null.rs @@ -70,9 +70,9 @@ impl PhysicalExpr for IsNotNullExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( - compute::is_not_null(array.as_ref())?, - ))), + ColumnarValue::Array(array) => { + Ok(ColumnarValue::from(compute::is_not_null(array.as_ref())?)) + } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(!scalar.is_null())), )), diff --git a/datafusion/src/physical_plan/expressions/is_null.rs b/datafusion/src/physical_plan/expressions/is_null.rs index dfa53f3f7d26..92ad0f08752e 100644 --- a/datafusion/src/physical_plan/expressions/is_null.rs +++ b/datafusion/src/physical_plan/expressions/is_null.rs @@ -70,9 +70,9 @@ impl PhysicalExpr for IsNullExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( - compute::is_null(array.as_ref())?, - ))), + ColumnarValue::Array(array) => { + Ok(ColumnarValue::from(compute::is_null(array.as_ref())?)) + } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(scalar.is_null())), )), diff --git a/datafusion/src/physical_plan/expressions/not.rs b/datafusion/src/physical_plan/expressions/not.rs index 23a1a46651de..14a655605d88 100644 --- a/datafusion/src/physical_plan/expressions/not.rs +++ b/datafusion/src/physical_plan/expressions/not.rs @@ -81,9 +81,9 @@ impl PhysicalExpr for NotExpr { "boolean_op failed to downcast array".to_owned(), ) })?; - Ok(ColumnarValue::Array(Arc::new( - arrow::compute::kernels::boolean::not(array)?, - ))) + Ok(ColumnarValue::from(arrow::compute::kernels::boolean::not( + array, + )?)) } ColumnarValue::Scalar(scalar) => { use std::convert::TryInto; diff --git a/datafusion/src/physical_plan/expressions/nullif.rs b/datafusion/src/physical_plan/expressions/nullif.rs index 7cc58ed2318f..ccee070300ac 100644 --- a/datafusion/src/physical_plan/expressions/nullif.rs +++ b/datafusion/src/physical_plan/expressions/nullif.rs @@ -138,7 +138,7 @@ mod tests { Some(4), Some(5), ]); - let a = ColumnarValue::Array(Arc::new(a)); + let a = ColumnarValue::from(a); let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); @@ -164,7 +164,7 @@ mod tests { // Ensure that arrays with no nulls can also invoke NULLIF() correctly fn nullif_int32_nonulls() -> Result<()> { let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); - let a = ColumnarValue::Array(Arc::new(a)); + let a = ColumnarValue::from(a); let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 960d7c5d8e0d..76adf7056e98 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -43,7 +43,7 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ - array::ArrayRef, + array::{ArrayRef, NullArray}, compute::kernels::length::{bit_length, length}, datatypes::TimeUnit, datatypes::{DataType, Field, Int32Type, Int64Type, Schema}, @@ -160,6 +160,8 @@ pub enum BuiltinScalarFunction { NullIf, /// octet_length OctetLength, + /// random + Random, /// regexp_replace RegexpReplace, /// repeat @@ -256,6 +258,7 @@ impl FromStr for BuiltinScalarFunction { "md5" => BuiltinScalarFunction::MD5, "nullif" => BuiltinScalarFunction::NullIf, "octet_length" => BuiltinScalarFunction::OctetLength, + "random" => BuiltinScalarFunction::Random, "regexp_replace" => BuiltinScalarFunction::RegexpReplace, "repeat" => BuiltinScalarFunction::Repeat, "replace" => BuiltinScalarFunction::Replace, @@ -298,15 +301,6 @@ pub fn return_type( // verify that this is a valid set of data types for this function data_types(&arg_types, &signature(fun))?; - if arg_types.is_empty() { - // functions currently cannot be evaluated without arguments, as they can't - // know the number of rows to return. - return Err(DataFusionError::Plan(format!( - "Function '{}' requires at least one argument", - fun - ))); - } - // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { @@ -427,6 +421,7 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::Random => Ok(DataType::Float64), BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -729,6 +724,7 @@ pub fn create_physical_expr( BuiltinScalarFunction::Ln => math_expressions::ln, BuiltinScalarFunction::Log10 => math_expressions::log10, BuiltinScalarFunction::Log2 => math_expressions::log2, + BuiltinScalarFunction::Random => math_expressions::random, BuiltinScalarFunction::Round => math_expressions::round, BuiltinScalarFunction::Signum => math_expressions::signum, BuiltinScalarFunction::Sin => math_expressions::sin, @@ -1373,12 +1369,16 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - // evaluate the arguments - let inputs = self - .args - .iter() - .map(|e| e.evaluate(batch)) - .collect::>>()?; + // evaluate the arguments, if there are no arguments we'll instead pass in a null array of + // batch size (as a convention) + let inputs = match self.args.len() { + 0 => vec![ColumnarValue::from(NullArray::new(batch.num_rows()))], + _ => self + .args + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?, + }; // evaluate the function let fun = self.fun.as_ref(); @@ -1386,7 +1386,7 @@ impl PhysicalExpr for ScalarFunctionExpr { } } -/// decorates a function to handle [`ScalarValue`]s by coverting them to arrays before calling the function +/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function /// and vice-versa after evaluation. pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation where diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs index 0e0bed2deac2..3ac18dd5a53f 100644 --- a/datafusion/src/physical_plan/math_expressions.rs +++ b/datafusion/src/physical_plan/math_expressions.rs @@ -16,11 +16,11 @@ // under the License. //! Math expressions - use super::{ColumnarValue, ScalarValue}; use crate::error::{DataFusionError, Result}; use arrow::array::{Float32Array, Float64Array}; use arrow::datatypes::DataType; +use rand::{thread_rng, Rng}; use std::sync::Arc; macro_rules! downcast_compute_op { @@ -100,3 +100,21 @@ math_unary_function!("exp", exp); math_unary_function!("ln", ln); math_unary_function!("log2", log2); math_unary_function!("log10", log10); + +/// random SQL function +pub fn random(args: &[ColumnarValue]) -> Result { + let len = match &args[0] { + ColumnarValue::Array(array) => array.len(), + _ => { + return Err(DataFusionError::Internal( + "Expect random function to take no param".to_string(), + )) + } + }; + let mut rng = thread_rng(); + let mut array = Vec::with_capacity(len); + for _ in 0..len { + array.push(Some(rng.gen_range(0.0..1.0))) + } + Ok(ColumnarValue::from(Float64Array::from(array))) +} diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index a8f6f0c35f00..2e2dbf7d4d4e 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -28,7 +28,13 @@ use crate::{error::Result, scalar::ScalarValue}; use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use arrow::{array::ArrayRef, datatypes::Field}; +use arrow::{ + array::{ + ArrayRef, BinaryOffsetSizeTrait, BooleanArray, GenericBinaryArray, + GenericStringArray, NullArray, PrimitiveArray, StringOffsetSizeTrait, + }, + datatypes::{ArrowPrimitiveType, Field}, +}; use async_trait::async_trait; use futures::stream::Stream; @@ -251,6 +257,35 @@ impl ColumnarValue { } } +impl From for ColumnarValue { + fn from(array: NullArray) -> Self { + ColumnarValue::Array(Arc::new(array)) + } +} +impl From> for ColumnarValue { + fn from(array: GenericBinaryArray) -> Self { + ColumnarValue::Array(Arc::new(array)) + } +} + +impl From for ColumnarValue { + fn from(array: BooleanArray) -> Self { + ColumnarValue::Array(Arc::new(array)) + } +} + +impl From> for ColumnarValue { + fn from(array: GenericStringArray) -> Self { + ColumnarValue::Array(Arc::new(array)) + } +} + +impl From> for ColumnarValue { + fn from(array: PrimitiveArray) -> Self { + ColumnarValue::Array(Arc::new(array)) + } +} + /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. pub trait PhysicalExpr: Send + Sync + Display + Debug { diff --git a/datafusion/src/physical_plan/string_expressions.rs b/datafusion/src/physical_plan/string_expressions.rs index 882fe30502fd..3f933e43661b 100644 --- a/datafusion/src/physical_plan/string_expressions.rs +++ b/datafusion/src/physical_plan/string_expressions.rs @@ -131,24 +131,14 @@ where match &args[0] { ColumnarValue::Array(a) => match a.data_type() { DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i32, - i32, - _, - _, - >( - &[a.as_ref()], op, name - )?))) + Ok(ColumnarValue::from( + unary_string_function::(&[a.as_ref()], op, name)?, + )) } DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i64, - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) + Ok(ColumnarValue::from( + unary_string_function::(&[a.as_ref()], op, name)?, + )) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function {}", @@ -309,7 +299,7 @@ pub fn concat(args: &[ColumnarValue]) -> Result { }) .collect::(); - Ok(ColumnarValue::Array(Arc::new(result))) + Ok(ColumnarValue::from(result)) } else { // short avenue with only scalars let initial = Some("".to_string()); diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index d9f84e7cb862..06d3739b53b2 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -46,6 +46,10 @@ pub fn coerce( schema: &Schema, signature: &Signature, ) -> Result>> { + if expressions.is_empty() { + return Ok(vec![]); + } + let current_types = expressions .iter() .map(|e| e.data_type(schema)) @@ -68,6 +72,9 @@ pub fn data_types( current_types: &[DataType], signature: &Signature, ) -> Result> { + if current_types.is_empty() { + return Ok(vec![]); + } let valid_types = get_valid_types(signature, current_types)?; if valid_types diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 3194ac7bf794..d8c9032f60e8 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -689,7 +689,7 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result { .expect("cast failed"); let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); - Ok(ColumnarValue::Array(Arc::new(array))) + Ok(ColumnarValue::from(array)) } else { unimplemented!() } @@ -2850,6 +2850,17 @@ async fn test_cast_expressions() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_random_expression() -> Result<()> { + let mut ctx = create_ctx()?; + let sql = format!("SELECT random() r1"); + let actual = execute(&mut ctx, sql.as_str()).await; + let r1 = actual[0][0].parse::().unwrap(); + assert!(0.0 <= r1); + assert!(r1 < 1.0); + Ok(()) +} + #[tokio::test] async fn test_cast_expressions_error() -> Result<()> { // sin(utf8) should error