Skip to content

Commit 2921595

Browse files
agscpp and Agaev Huseyn
committed
Fix functions with Volatility::Volatile and parameters (apache#13001)
Co-authored-by: Agaev Huseyn <h.agaev@vkteam.ru>
1 parent 4cae813 commit 2921595

File tree

3 files changed

+215
-6
lines changed

3 files changed

+215
-6
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 184 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@
1616
// under the License.
1717

1818
use std::any::Any;
19+
use std::collections::HashMap;
1920
use std::sync::Arc;
2021

22+
use arrow::array::as_string_array;
2123
use arrow::compute::kernels::numeric::add;
22-
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
24+
use arrow_array::{
25+
ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
26+
};
2327
use arrow_schema::{DataType, Field, Schema};
2428
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
2529
use datafusion::prelude::*;
@@ -476,6 +480,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
476480
Ok(())
477481
}
478482

483+
/// Volatile UDF that should append a different value to each row
484+
#[derive(Debug)]
485+
struct AddIndexToStringVolatileScalarUDF {
486+
name: String,
487+
signature: Signature,
488+
return_type: DataType,
489+
}
490+
491+
impl AddIndexToStringVolatileScalarUDF {
492+
fn new() -> Self {
493+
Self {
494+
name: "add_index_to_string".to_string(),
495+
signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
496+
return_type: DataType::Utf8,
497+
}
498+
}
499+
}
500+
501+
impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
502+
fn as_any(&self) -> &dyn Any {
503+
self
504+
}
505+
506+
fn name(&self) -> &str {
507+
&self.name
508+
}
509+
510+
fn signature(&self) -> &Signature {
511+
&self.signature
512+
}
513+
514+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
515+
Ok(self.return_type.clone())
516+
}
517+
518+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
519+
not_impl_err!("index_with_offset function does not accept arguments")
520+
}
521+
522+
fn invoke_batch(
523+
&self,
524+
args: &[ColumnarValue],
525+
number_rows: usize,
526+
) -> Result<ColumnarValue> {
527+
let answer = match &args[0] {
528+
// When called with static arguments, the result is returned as an array.
529+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => {
530+
let mut answer = vec![];
531+
for index in 1..=number_rows {
532+
// When calling a function with immutable arguments, the result is returned with ")".
533+
// Example: SELECT add_index_to_string('const_value') FROM table;
534+
answer.push(index.to_string() + ") " + value);
535+
}
536+
answer
537+
}
538+
// The result is returned as an array when called with dynamic arguments.
539+
ColumnarValue::Array(array) => {
540+
let string_array = as_string_array(array);
541+
let mut counter = HashMap::<&str, u64>::new();
542+
string_array
543+
.iter()
544+
.map(|value| {
545+
let value = value.expect("Unexpected null");
546+
let index = counter.get(value).unwrap_or(&0) + 1;
547+
counter.insert(value, index);
548+
549+
// When calling a function with mutable arguments, the result is returned with ".".
550+
// Example: SELECT add_index_to_string(table.value) FROM table;
551+
index.to_string() + ". " + value
552+
})
553+
.collect()
554+
}
555+
_ => unimplemented!(),
556+
};
557+
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
558+
}
559+
}
560+
561+
#[tokio::test]
562+
async fn volatile_scalar_udf_with_params() -> Result<()> {
563+
{
564+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
565+
566+
let batch = RecordBatch::try_new(
567+
Arc::new(schema.clone()),
568+
vec![Arc::new(StringArray::from(vec![
569+
"test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2",
570+
]))],
571+
)?;
572+
let ctx = SessionContext::new();
573+
574+
ctx.register_batch("t", batch)?;
575+
576+
let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();
577+
578+
ctx.register_udf(ScalarUDF::from(get_new_str_udf));
579+
580+
let result =
581+
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters
582+
.await?;
583+
let expected = [
584+
"+-----------+",
585+
"| str |",
586+
"+-----------+",
587+
"| 1. test_1 |",
588+
"| 2. test_1 |",
589+
"| 3. test_1 |",
590+
"| 1. test_2 |",
591+
"| 2. test_2 |",
592+
"| 4. test_1 |",
593+
"| 3. test_2 |",
594+
"+-----------+",
595+
];
596+
assert_batches_eq!(expected, &result);
597+
598+
let result =
599+
plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters
600+
.await?;
601+
let expected = [
602+
"+---------+",
603+
"| str |",
604+
"+---------+",
605+
"| 1) test |",
606+
"| 2) test |",
607+
"| 3) test |",
608+
"| 4) test |",
609+
"| 5) test |",
610+
"| 6) test |",
611+
"| 7) test |",
612+
"+---------+",
613+
];
614+
assert_batches_eq!(expected, &result);
615+
616+
let result =
617+
plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters
618+
.await?;
619+
let expected = [
620+
"+---------------+",
621+
"| str |",
622+
"+---------------+",
623+
"| 1) test_value |",
624+
"+---------------+",
625+
];
626+
assert_batches_eq!(expected, &result);
627+
}
628+
{
629+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
630+
631+
let batch = RecordBatch::try_new(
632+
Arc::new(schema.clone()),
633+
vec![Arc::new(StringArray::from(vec![
634+
"test_1", "test_1", "test_1",
635+
]))],
636+
)?;
637+
let ctx = SessionContext::new();
638+
639+
ctx.register_batch("t", batch)?;
640+
641+
let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();
642+
643+
ctx.register_udf(ScalarUDF::from(get_new_str_udf));
644+
645+
let result =
646+
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t")
647+
.await?;
648+
let expected = [
649+
"+-----------+", //
650+
"| str |", //
651+
"+-----------+", //
652+
"| 1. test_1 |", //
653+
"| 2. test_1 |", //
654+
"| 3. test_1 |", //
655+
"+-----------+",
656+
];
657+
assert_batches_eq!(expected, &result);
658+
}
659+
Ok(())
660+
}
661+
479662
#[derive(Debug)]
480663
struct CastToI64UDF {
481664
signature: Signature,

datafusion/expr/src/udf.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,17 @@ impl ScalarUDF {
189189
self.inner.invoke(args)
190190
}
191191

192+
/// Invoke the function with `args` and number of rows, returning the appropriate result.
193+
///
194+
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
195+
pub fn invoke_batch(
196+
&self,
197+
args: &[ColumnarValue],
198+
number_rows: usize,
199+
) -> Result<ColumnarValue> {
200+
self.inner.invoke_batch(args, number_rows)
201+
}
202+
192203
/// Invoke the function without `args`, given only the number of rows, returning the appropriate result.
193204
///
194205
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
@@ -408,7 +419,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
408419
/// to arrays, which will likely be simpler code, but be slower.
409420
///
410421
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
411-
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;
422+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
423+
not_impl_err!(
424+
"Function {} does not implement invoke but called",
425+
self.name()
426+
)
427+
}
428+
429+
/// Invoke the function with `args` and the number of rows,
430+
/// returning the appropriate result.
431+
fn invoke_batch(
432+
&self,
433+
args: &[ColumnarValue],
434+
number_rows: usize,
435+
) -> Result<ColumnarValue> {
436+
match args.is_empty() {
437+
true => self.invoke_no_args(number_rows),
438+
false => self.invoke(args),
439+
}
440+
}
412441

413442
/// Invoke the function without `args`, instead the number of rows are provided,
414443
/// returning the appropriate result.

datafusion/physical-expr/src/scalar_function.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
129129
.collect::<Result<Vec<_>>>()?;
130130

131131
// evaluate the function
132-
let output = match self.args.is_empty() {
133-
true => self.fun.invoke_no_args(batch.num_rows()),
134-
false => self.fun.invoke(&inputs),
135-
}?;
132+
let output = self.fun.invoke_batch(&inputs, batch.num_rows())?;
136133

137134
if let ColumnarValue::Array(array) = &output {
138135
if array.len() != batch.num_rows() {

0 commit comments

Comments
 (0)