-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ScalarUDF: Remove supports_zero_argument
and avoid creating null array for empty args
#10193
Changes from 7 commits
eabfe68
7b529d9
03ec8b5
36e685e
7b04c0b
7c10382
864d197
5b51fb7
88d2a33
7c81776
bd4c65b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,10 +16,7 @@ | |
// under the License. | ||
|
||
use arrow::compute::kernels::numeric::add; | ||
use arrow_array::{ | ||
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, | ||
}; | ||
use arrow_schema::DataType::Float64; | ||
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch}; | ||
use arrow_schema::{DataType, Field, Schema}; | ||
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; | ||
use datafusion::prelude::*; | ||
|
@@ -36,9 +33,7 @@ use datafusion_expr::{ | |
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable, | ||
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, | ||
}; | ||
use rand::{thread_rng, Rng}; | ||
use std::any::Any; | ||
use std::iter; | ||
use std::sync::Arc; | ||
|
||
/// test that casting happens on udfs. | ||
|
@@ -403,123 +398,6 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
pub struct RandomUDF { | ||
signature: Signature, | ||
} | ||
|
||
impl RandomUDF { | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::any(0, Volatility::Volatile), | ||
} | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for RandomUDF { | ||
fn as_any(&self) -> &dyn std::any::Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
"random_udf" | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(Float64) | ||
} | ||
|
||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
let len: usize = match &args[0] { | ||
// This udf is always invoked with zero argument so its argument | ||
// is a null array indicating the batch size. | ||
ColumnarValue::Array(array) if array.data_type().is_null() => array.len(), | ||
_ => { | ||
return Err(datafusion::error::DataFusionError::Internal( | ||
"Invalid argument type".to_string(), | ||
)) | ||
} | ||
}; | ||
let mut rng = thread_rng(); | ||
let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len); | ||
let array = Float64Array::from_iter_values(values); | ||
Ok(ColumnarValue::Array(Arc::new(array))) | ||
} | ||
} | ||
|
||
/// Ensure that a user defined function with zero argument will be invoked | ||
/// with a null array indicating the batch size. | ||
#[tokio::test] | ||
async fn test_user_defined_functions_zero_argument() -> Result<()> { | ||
let ctx = SessionContext::new(); | ||
|
||
let schema = Arc::new(Schema::new(vec![Field::new( | ||
"index", | ||
DataType::UInt8, | ||
false, | ||
)])); | ||
|
||
let batch = RecordBatch::try_new( | ||
schema, | ||
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))], | ||
)?; | ||
|
||
ctx.register_batch("data_table", batch)?; | ||
|
||
let random_normal_udf = ScalarUDF::from(RandomUDF::new()); | ||
ctx.register_udf(random_normal_udf); | ||
|
||
let result = plan_and_collect( | ||
&ctx, | ||
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table", | ||
) | ||
.await?; | ||
|
||
assert_eq!(result.len(), 1); | ||
let batch = &result[0]; | ||
let random_udf = batch | ||
.column(0) | ||
.as_any() | ||
.downcast_ref::<Float64Array>() | ||
.unwrap(); | ||
let native_random = batch | ||
.column(1) | ||
.as_any() | ||
.downcast_ref::<Float64Array>() | ||
.unwrap(); | ||
|
||
assert_eq!(random_udf.len(), native_random.len()); | ||
|
||
let mut previous = -1.0; | ||
for i in 0..random_udf.len() { | ||
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0); | ||
assert!(random_udf.value(i) != previous); | ||
previous = random_udf.value(i); | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn deregister_udf() -> Result<()> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is just moved , the test remains |
||
let random_normal_udf = ScalarUDF::from(RandomUDF::new()); | ||
let ctx = SessionContext::new(); | ||
|
||
ctx.register_udf(random_normal_udf.clone()); | ||
|
||
assert!(ctx.udfs().contains("random_udf")); | ||
|
||
ctx.deregister_udf("random_udf"); | ||
|
||
assert!(!ctx.udfs().contains("random_udf")); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
struct CastToI64UDF { | ||
signature: Signature, | ||
|
@@ -615,6 +493,22 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn deregister_udf() -> Result<()> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
let cast2i64 = ScalarUDF::from(CastToI64UDF::new()); | ||
let ctx = SessionContext::new(); | ||
|
||
ctx.register_udf(cast2i64.clone()); | ||
|
||
assert!(ctx.udfs().contains("cast_to_i64")); | ||
|
||
ctx.deregister_udf("cast_to_i64"); | ||
|
||
assert!(!ctx.udfs().contains("cast_to_i64")); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
struct TakeUDF { | ||
signature: Signature, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -322,11 +322,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |
/// The function will be invoked passed with the slice of [`ColumnarValue`] | ||
/// (either scalar or array). | ||
/// | ||
/// # Zero Argument Functions | ||
/// If the function has zero parameters (e.g. `now()`) it will be passed a | ||
/// single element slice which is a a null array to indicate the batch's row | ||
/// count (so the function can know the resulting array size). | ||
/// | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we please leave a note about what is required to implement |
||
/// # Performance | ||
/// | ||
/// For the best performance, the implementations of `invoke` should handle | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,11 @@ | |
// under the License. | ||
|
||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
||
use arrow::array::Float64Array; | ||
use arrow::datatypes::DataType; | ||
use arrow::datatypes::DataType::Float64; | ||
|
||
use datafusion_common::{exec_err, Result}; | ||
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility}; | ||
use datafusion_expr::{ScalarUDFImpl, Signature}; | ||
|
||
|
@@ -62,12 +60,10 @@ impl ScalarUDFImpl for PiFunc { | |
Ok(Float64) | ||
} | ||
|
||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
if !matches!(&args[0], ColumnarValue::Array(_)) { | ||
return exec_err!("Expect pi function to take no param"); | ||
} | ||
let array = Float64Array::from_value(std::f64::consts::PI, 1); | ||
Ok(ColumnarValue::Array(Arc::new(array))) | ||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think signature check is enough, so just ignore args |
||
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( | ||
std::f64::consts::PI, | ||
)))) | ||
} | ||
|
||
fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,15 +16,12 @@ | |
// under the License. | ||
|
||
use std::any::Any; | ||
use std::iter; | ||
use std::sync::Arc; | ||
|
||
use arrow::array::Float64Array; | ||
use arrow::datatypes::DataType; | ||
use arrow::datatypes::DataType::Float64; | ||
use rand::{thread_rng, Rng}; | ||
|
||
use datafusion_common::{exec_err, Result}; | ||
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_expr::ColumnarValue; | ||
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; | ||
|
||
|
@@ -64,45 +61,9 @@ impl ScalarUDFImpl for RandomFunc { | |
Ok(Float64) | ||
} | ||
|
||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
random(args) | ||
} | ||
} | ||
|
||
/// Random SQL function | ||
fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
let len: usize = match &args[0] { | ||
ColumnarValue::Array(array) => array.len(), | ||
_ => return exec_err!("Expect random function to take no param"), | ||
}; | ||
let mut rng = thread_rng(); | ||
let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len); | ||
let array = Float64Array::from_iter_values(values); | ||
Ok(ColumnarValue::Array(Arc::new(array))) | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use std::sync::Arc; | ||
|
||
use arrow::array::NullArray; | ||
|
||
use datafusion_common::cast::as_float64_array; | ||
use datafusion_expr::ColumnarValue; | ||
|
||
use crate::math::random::random; | ||
|
||
#[test] | ||
fn test_random_expression() { | ||
let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; | ||
let array = random(&args) | ||
.expect("failed to initialize function random") | ||
.into_array(1) | ||
.expect("Failed to convert to array"); | ||
let floats = | ||
as_float64_array(&array).expect("failed to initialize function random"); | ||
|
||
assert_eq!(floats.len(), 1); | ||
assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); | ||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the idea here is that expectation is that For example, when I run > create table foo as values (1), (2), (3), (4), (5);
0 row(s) fetched.
Elapsed 0.018 seconds.
> select column1, random() from foo;
+---------+--------------------+
| column1 | random() |
+---------+--------------------+
| 1 | 0.9594375709000513 |
| 2 | 0.9594375709000513 |
| 3 | 0.9594375709000513 |
| 4 | 0.9594375709000513 |
| 5 | 0.9594375709000513 |
+---------+--------------------+
5 row(s) fetched.
Elapsed 0.012 seconds. But I expect that each row has a different value for However, since none of the tests failed, clearly we have a gap in test coverage 🤔 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch! Let me think about how to design it, I would prefer something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about adding a That might make it clear what was happening and would provide clear semantics about what to do in this case 🤔 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't see the message then, it is also a good idea. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we test the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have test like
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I means the I test the function in the spark
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. file a issue for this: #10247 |
||
let mut rng = thread_rng(); | ||
let val = rng.gen_range(0.0..1.0); | ||
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(val)))) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have
RandomFunc
, they are the same so remove it.