Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ScalarUDF: Remove supports_zero_argument and avoid creating null array for empty args #10193

Merged
4 changes: 0 additions & 4 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1471,7 +1470,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 3))),
Expand Down Expand Up @@ -1543,7 +1541,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1612,7 +1609,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d_new", 3))),
Expand Down
140 changes: 17 additions & 123 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
Expand All @@ -36,9 +33,7 @@ use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
use std::sync::Arc;

/// test that casting happens on udfs.
Expand Down Expand Up @@ -403,123 +398,6 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
Ok(())
}

#[derive(Debug)]
pub struct RandomUDF {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have RandomFunc, which does the same thing, so this one is removed.

signature: Signature,
}

impl RandomUDF {
    /// Creates a `RandomUDF` whose signature is `Signature::any(0, ..)`
    /// with `Volatility::Volatile`.
    pub fn new() -> Self {
        let signature = Signature::any(0, Volatility::Volatile);
        Self { signature }
    }
}

/// `ScalarUDFImpl` for `RandomUDF`: a UDF registered as `random_udf` that
/// reports a `Float64` return type and produces one random value per row.
impl ScalarUDFImpl for RandomUDF {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "random_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always `Float64`; the argument types are ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(Float64)
    }

    /// Returns a `Float64` array with one value drawn from `0.1..1.0` per row.
    ///
    /// # Errors
    ///
    /// Returns an internal error when `args` is empty or when its first
    /// element is not an array of the null data type.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // This udf is always invoked with zero argument so its argument
        // is a null array indicating the batch size.
        //
        // `first()` is used instead of `args[0]` so that an empty slice is
        // reported through the error below rather than panicking.
        let len: usize = match args.first() {
            Some(ColumnarValue::Array(array)) if array.data_type().is_null() => {
                array.len()
            }
            _ => {
                return Err(datafusion::error::DataFusionError::Internal(
                    "Invalid argument type".to_string(),
                ))
            }
        };
        let mut rng = thread_rng();
        let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len);
        let array = Float64Array::from_iter_values(values);
        Ok(ColumnarValue::Array(Arc::new(array)))
    }
}

/// Ensure that a user defined function with zero argument will be invoked
/// with a null array indicating the batch size.
#[tokio::test]
async fn test_user_defined_functions_zero_argument() -> Result<()> {
let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![Field::new(
"index",
DataType::UInt8,
false,
)]));

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))],
)?;

ctx.register_batch("data_table", batch)?;

let random_normal_udf = ScalarUDF::from(RandomUDF::new());
ctx.register_udf(random_normal_udf);

let result = plan_and_collect(
&ctx,
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table",
)
.await?;

assert_eq!(result.len(), 1);
let batch = &result[0];
let random_udf = batch
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let native_random = batch
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();

assert_eq!(random_udf.len(), native_random.len());

let mut previous = -1.0;
for i in 0..random_udf.len() {
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0);
assert!(random_udf.value(i) != previous);
previous = random_udf.value(i);
}

Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is just moved; the test remains.

let random_normal_udf = ScalarUDF::from(RandomUDF::new());
let ctx = SessionContext::new();

ctx.register_udf(random_normal_udf.clone());

assert!(ctx.udfs().contains("random_udf"));

ctx.deregister_udf("random_udf");

assert!(!ctx.udfs().contains("random_udf"));

Ok(())
}

#[derive(Debug)]
struct CastToI64UDF {
signature: Signature,
Expand Down Expand Up @@ -615,6 +493,22 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let cast2i64 = ScalarUDF::from(CastToI64UDF::new());
let ctx = SessionContext::new();

ctx.register_udf(cast2i64.clone());

assert!(ctx.udfs().contains("cast_to_i64"));

ctx.deregister_udf("cast_to_i64");

assert!(!ctx.udfs().contains("cast_to_i64"));

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub fn data_types(
);
}
}

let valid_types = get_valid_types(&signature.type_signature, current_types)?;

if valid_types
Expand Down
5 changes: 0 additions & 5 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// The function will be invoked passed with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// # Zero Argument Functions
/// If the function has zero parameters (e.g. `now()`) it will be passed a
/// single element slice which is a null array to indicate the batch's row
/// count (so the function can know the resulting array size).
///
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please leave a note about what is required to implement Zero Argument Functions? I think the expectation is that the output is a single ColumnarValue::Scalar, rather than an Array

/// # Performance
///
/// For the best performance, the implementations of `invoke` should handle
Expand Down
14 changes: 5 additions & 9 deletions datafusion/functions/src/math/pi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::Float64Array;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Float64;

use datafusion_common::{exec_err, Result};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

Expand Down Expand Up @@ -62,12 +60,10 @@ impl ScalarUDFImpl for PiFunc {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if !matches!(&args[0], ColumnarValue::Array(_)) {
return exec_err!("Expect pi function to take no param");
}
let array = Float64Array::from_value(std::f64::consts::PI, 1);
Ok(ColumnarValue::Array(Arc::new(array)))
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the signature check is enough, so we can just ignore args.

Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
std::f64::consts::PI,
))))
}

fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> {
Expand Down
49 changes: 5 additions & 44 deletions datafusion/functions/src/math/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@
// under the License.

use std::any::Any;
use std::iter;
use std::sync::Arc;

use arrow::array::Float64Array;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Float64;
use rand::{thread_rng, Rng};

use datafusion_common::{exec_err, Result};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};

Expand Down Expand Up @@ -64,45 +61,9 @@ impl ScalarUDFImpl for RandomFunc {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
random(args)
}
}

/// Random SQL function
///
/// Produces a `Float64` array with one value drawn from `0.0..1.0` for each
/// row of the array passed as the first argument; only that array's length
/// is used.
///
/// # Errors
///
/// Returns an execution error when `args` is empty or when its first element
/// is not an array.
fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    // `first()` is used instead of `args[0]` so that an empty slice yields
    // the error below rather than a panic.
    let len: usize = match args.first() {
        Some(ColumnarValue::Array(array)) => array.len(),
        _ => return exec_err!("Expect random function to take no param"),
    };
    let mut rng = thread_rng();
    let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len);
    let array = Float64Array::from_iter_values(values);
    Ok(ColumnarValue::Array(Arc::new(array)))
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use arrow::array::NullArray;

use datafusion_common::cast::as_float64_array;
use datafusion_expr::ColumnarValue;

use crate::math::random::random;

#[test]
fn test_random_expression() {
let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))];
let array = random(&args)
.expect("failed to initialize function random")
.into_array(1)
.expect("Failed to convert to array");
let floats =
as_float64_array(&array).expect("failed to initialize function random");

assert_eq!(floats.len(), 1);
assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the expectation here is that random() is invoked once per row rather than once per batch, and the only way it knew how many rows to produce was by being passed a null array 🤔

For example, when I run datafusion-cli from this PR to call random() the same value is returned for each row:

> create table foo as values (1), (2), (3), (4), (5);
0 row(s) fetched.
Elapsed 0.018 seconds.

> select column1, random() from foo;
+---------+--------------------+
| column1 | random()           |
+---------+--------------------+
| 1       | 0.9594375709000513 |
| 2       | 0.9594375709000513 |
| 3       | 0.9594375709000513 |
| 4       | 0.9594375709000513 |
| 5       | 0.9594375709000513 |
+---------+--------------------+
5 row(s) fetched.
Elapsed 0.012 seconds.

But I expect that each row has a different value for random()

However, since none of the tests failed, clearly we have a gap in test coverage 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! Let me think about how to design it. I would prefer something like support_random to special-case random().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about adding an invoke_no_args(num_rows: usize) method to ScalarUDFImpl, with a default implementation that returns a "not implemented" error?

That might make it clear what was happening and would provide clear semantics about what to do in this case 🤔

Copy link
Contributor Author

@jayzhan211 jayzhan211 Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see this message at the time; it is also a good idea.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we test the uuid function? I think uuid has the same property as random.
cc @jayzhan211

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have test like

query II
SELECT octet_length(uuid()), length(uuid())
----
36 36

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have test like

query II
SELECT octet_length(uuid()), length(uuid())
----
36 36

I mean that the uuid function should also be invoked once per row rather than once per batch, like the random() function mentioned by alamb.

I tested the function in Spark:

spark-sql> desc test;
col1    int     NULL
Time taken: 0.065 seconds, Fetched 1 row(s)


spark-sql> select * from test;
1
2



spark-sql> select *,uuid() from test;
1       6b04b66c-2e6c-4925-8b18-a9d51d5ed80a
2       3e0be0c2-9ff2-422a-8f3f-5cdb6551264b

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed an issue for this: #10247

let mut rng = thread_rng();
let val = rng.gen_range(0.0..1.0);
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(val))))
}
}
18 changes: 5 additions & 13 deletions datafusion/functions/src/string/uuid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@
// under the License.

use std::any::Any;
use std::iter;
use std::sync::Arc;

use arrow::array::GenericStringArray;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Utf8;
use uuid::Uuid;

use datafusion_common::{exec_err, Result};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

Expand Down Expand Up @@ -60,14 +57,9 @@ impl ScalarUDFImpl for UuidFunc {

/// Prints random (v4) uuid values per row
/// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let len: usize = match &args[0] {
ColumnarValue::Array(array) => array.len(),
_ => return exec_err!("Expect uuid function to take no param"),
};

let values = iter::repeat_with(|| Uuid::new_v4().to_string()).take(len);
let array = GenericStringArray::<i32>::from_iter_values(values);
Ok(ColumnarValue::Array(Arc::new(array)))
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
Uuid::new_v4().to_string(),
))))
}
}
1 change: 0 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ pub fn create_physical_expr(
input_phy_exprs.to_vec(),
return_type,
fun.monotonicity()?,
fun.signature().type_signature.supports_zero_argument(),
)))
}

Expand Down
Loading
Loading