Skip to content

Commit 1e6ad19

Browse files
author
Agaev Huseyn
committed
Fix functions with Volatility::Volatile and parameters
1 parent f718fe2 commit 1e6ad19

File tree

3 files changed

+226
-5
lines changed

3 files changed

+226
-5
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
// under the License.
1717

1818
use std::any::Any;
19+
use std::collections::HashMap;
1920
use std::hash::{DefaultHasher, Hash, Hasher};
2021
use std::sync::Arc;
2122

23+
use arrow::array::as_string_array;
2224
use arrow::compute::kernels::numeric::add;
2325
use arrow_array::builder::BooleanBuilder;
2426
use arrow_array::cast::AsArray;
@@ -483,6 +485,196 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
483485
Ok(())
484486
}
485487

488+
/// Volatile UDF that should be append a different value to each row
489+
struct AddIndexToStringScalarUDF {
490+
name: String,
491+
signature: Signature,
492+
return_type: DataType,
493+
}
494+
495+
impl AddIndexToStringScalarUDF {
496+
fn new() -> Self {
497+
Self {
498+
name: "add_index_to_string".to_string(),
499+
signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
500+
return_type: DataType::Utf8,
501+
}
502+
}
503+
}
504+
505+
impl std::fmt::Debug for AddIndexToStringScalarUDF {
506+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
507+
f.debug_struct("ScalarUDF")
508+
.field("name", &self.name)
509+
.field("signature", &self.signature)
510+
.field("fun", &"<FUNC>")
511+
.finish()
512+
}
513+
}
514+
515+
impl ScalarUDFImpl for AddIndexToStringScalarUDF {
516+
fn as_any(&self) -> &dyn Any {
517+
self
518+
}
519+
520+
fn name(&self) -> &str {
521+
&self.name
522+
}
523+
524+
fn signature(&self) -> &Signature {
525+
&self.signature
526+
}
527+
528+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
529+
Ok(self.return_type.clone())
530+
}
531+
532+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
533+
not_impl_err!("index_with_offset function does not accept arguments")
534+
}
535+
536+
fn invoke_batch(
537+
&self,
538+
_args: &[ColumnarValue],
539+
_number_rows: usize,
540+
) -> Result<ColumnarValue> {
541+
let answer = match &_args[0] {
542+
// When called with static arguments, the result is returned as an array.
543+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => {
544+
let mut answer = vec![];
545+
for index in 1..=_number_rows {
546+
// When calling a function with immutable arguments, the result is returned with ")".
547+
// Example: SELECT add_index_to_string('const_value') FROM table;
548+
answer.push(index.to_string() + ") " + value);
549+
}
550+
answer
551+
}
552+
// The result is returned as an array when called with dynamic arguments.
553+
ColumnarValue::Array(array) => {
554+
let string_array = as_string_array(array);
555+
let mut counter = HashMap::<&str, u64>::new();
556+
string_array
557+
.iter()
558+
.map(|value| {
559+
let value = value.expect("Unexpected null");
560+
let index = counter.get(value).unwrap_or(&0) + 1;
561+
counter.insert(value, index);
562+
563+
// When calling a function with mutable arguments, the result is returned with ".".
564+
// Example: SELECT add_index_to_string(table.value) FROM table;
565+
index.to_string() + ". " + value
566+
})
567+
.collect()
568+
}
569+
_ => unimplemented!(),
570+
};
571+
Ok(ColumnarValue::Array(
572+
Arc::new(StringArray::from(answer)) as ArrayRef
573+
))
574+
}
575+
}
576+
577+
#[tokio::test]
578+
async fn volatile_scalar_udf_with_params() -> Result<()> {
579+
{
580+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
581+
582+
let batch = RecordBatch::try_new(
583+
Arc::new(schema.clone()),
584+
vec![Arc::new(StringArray::from(vec![
585+
"test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2",
586+
]))],
587+
)?;
588+
let ctx = SessionContext::new();
589+
590+
ctx.register_batch("t", batch)?;
591+
592+
let get_new_str_udf = AddIndexToStringScalarUDF::new();
593+
594+
ctx.register_udf(ScalarUDF::from(get_new_str_udf));
595+
596+
let result =
597+
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters
598+
.await?;
599+
let expected = [
600+
"+-----------+",
601+
"| str |",
602+
"+-----------+",
603+
"| 1. test_1 |",
604+
"| 2. test_1 |",
605+
"| 3. test_1 |",
606+
"| 1. test_2 |",
607+
"| 2. test_2 |",
608+
"| 4. test_1 |",
609+
"| 3. test_2 |",
610+
"+-----------+",
611+
];
612+
assert_batches_eq!(expected, &result);
613+
614+
let result =
615+
plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters
616+
.await?;
617+
let expected = [
618+
"+---------+",
619+
"| str |",
620+
"+---------+",
621+
"| 1) test |",
622+
"| 2) test |",
623+
"| 3) test |",
624+
"| 4) test |",
625+
"| 5) test |",
626+
"| 6) test |",
627+
"| 7) test |",
628+
"+---------+",
629+
];
630+
assert_batches_eq!(expected, &result);
631+
632+
let result =
633+
plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters
634+
.await?;
635+
let expected = [
636+
"+---------------+",
637+
"| str |",
638+
"+---------------+",
639+
"| 1) test_value |",
640+
"+---------------+",
641+
];
642+
assert_batches_eq!(expected, &result);
643+
}
644+
{
645+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
646+
647+
let batch = RecordBatch::try_new(
648+
Arc::new(schema.clone()),
649+
vec![Arc::new(StringArray::from(vec![
650+
"test_1", "test_1", "test_1",
651+
]))],
652+
)?;
653+
let ctx = SessionContext::new();
654+
655+
ctx.register_batch("t", batch)?;
656+
657+
let get_new_str_udf = AddIndexToStringScalarUDF::new();
658+
659+
ctx.register_udf(ScalarUDF::from(get_new_str_udf));
660+
661+
let result =
662+
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t")
663+
.await?;
664+
let expected = [
665+
"+-----------+", //
666+
"| str |", //
667+
"+-----------+", //
668+
"| 1. test_1 |", //
669+
"| 2. test_1 |", //
670+
"| 3. test_1 |", //
671+
"+-----------+",
672+
];
673+
assert_batches_eq!(expected, &result);
674+
}
675+
Ok(())
676+
}
677+
486678
#[derive(Debug)]
487679
struct CastToI64UDF {
488680
signature: Signature,

datafusion/expr/src/udf.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ impl ScalarUDF {
201201
self.inner.is_nullable(args, schema)
202202
}
203203

204+
/// Invoke the function with `args` and number of rows, returning the appropriate result.
205+
///
206+
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
207+
pub fn invoke_batch(
208+
&self,
209+
args: &[ColumnarValue],
210+
number_rows: usize,
211+
) -> Result<ColumnarValue> {
212+
self.inner.invoke_batch(args, number_rows)
213+
}
214+
204215
/// Invoke the function without `args`, given only the number of rows, returning the appropriate result.
205216
///
206217
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
@@ -467,7 +478,28 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
467478
/// to arrays, which will likely be simpler code, but be slower.
468479
///
469480
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
470-
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;
481+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
482+
not_impl_err!(
483+
"Function {} does not implement invoke but called",
484+
self.name()
485+
)
486+
}
487+
488+
/// Invoke the function with `args` and the number of rows,
489+
/// returning the appropriate result.
490+
///
491+
/// The function should be used for signatures with [`datafusion_expr_common::signature::Volatility::Volatile`]
492+
/// and with arguments.
493+
fn invoke_batch(
494+
&self,
495+
_args: &[ColumnarValue],
496+
_number_rows: usize,
497+
) -> Result<ColumnarValue> {
498+
match _args.is_empty() {
499+
true => self.invoke_no_args(_number_rows),
500+
false => self.invoke(_args),
501+
}
502+
}
471503

472504
/// Invoke the function without `args`, instead the number of rows are provided,
473505
/// returning the appropriate result.

datafusion/physical-expr/src/scalar_function.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
141141
.collect::<Result<Vec<_>>>()?;
142142

143143
// evaluate the function
144-
let output = match self.args.is_empty() {
145-
true => self.fun.invoke_no_args(batch.num_rows()),
146-
false => self.fun.invoke(&inputs),
147-
}?;
144+
let output = self.fun.invoke_batch(&inputs, batch.num_rows())?;
148145

149146
if let ColumnarValue::Array(array) = &output {
150147
if array.len() != batch.num_rows() {

0 commit comments

Comments
 (0)